{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 的士调度 Taxi-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import gym"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 环境使用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "观察空间 = Discrete(500)\n",
      "动作空间 = Discrete(6)\n",
      "状态数量 = 500\n",
      "动作数量 = 6\n"
     ]
    }
   ],
   "source": [
    "# Create the Taxi-v3 environment and seed it with 0.\n",
    "env = gym.make('Taxi-v3')\n",
    "env.seed(0)\n",
    "# Print the observation/action spaces and the number of discrete states and actions.\n",
    "print('观察空间 = {}'.format(env.observation_space))\n",
    "print('动作空间 = {}'.format(env.action_space))\n",
    "print('状态数量 = {}'.format(env.observation_space.n))\n",
    "print('动作数量 = {}'.format(env.action_space.n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 1 1 2\n",
      "的士位置 = (0, 1)\n",
      "乘客位置 = (0, 4)\n",
      "目标位置 = (4, 0)\n",
      "+---------+\n",
      "|R:\u001b[43m \u001b[0m| : :\u001b[34;1mG\u001b[0m|\n",
      "| : | : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|\u001b[35mY\u001b[0m| : |B: |\n",
      "+---------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Start an episode and unpack the integer state into its four components.\n",
    "state = env.reset()\n",
    "taxirow, taxicol, passloc, destidx = env.unwrapped.decode(state)\n",
    "print(taxirow, taxicol, passloc, destidx)\n",
    "print('的士位置 = {}'.format((taxirow, taxicol)))\n",
    "# passloc and destidx are used as indices into env.unwrapped.locs to look up grid positions.\n",
    "print('乘客位置 = {}'.format(env.unwrapped.locs[passloc]))\n",
    "print('目标位置 = {}'.format(env.unwrapped.locs[destidx]))\n",
    "env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(126, -1, False, {'prob': 1.0})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take action 0; step() returns (next observation, reward, done flag, info dict).\n",
    "env.step(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---------+\n",
      "|R: | : :\u001b[34;1mG\u001b[0m|\n",
      "| :\u001b[43m \u001b[0m| : : |\n",
      "| : : : : |\n",
      "| | : | : |\n",
      "|\u001b[35mY\u001b[0m| : |B: |\n",
      "+---------+\n",
      "  (South)\n"
     ]
    }
   ],
   "source": [
    "# Draw the grid again to show the state reached by the step above.\n",
    "env.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SARSA 算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARSAAgent:\n",
    "    \"\"\"Tabular SARSA learner that acts epsilon-greedily on its Q-table.\n",
    "\n",
    "    The table q has one row per state and one column per action and\n",
    "    starts out filled with zeros.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.2, epsilon=.01):\n",
    "        \"\"\"Store the hyper-parameters and allocate the Q-table.\n",
    "\n",
    "        env: environment exposing observation_space.n and action_space.n.\n",
    "        gamma: discount factor applied to the bootstrapped value.\n",
    "        learning_rate: step size of every Q-table update.\n",
    "        epsilon: threshold below which a uniformly random action is taken.\n",
    "        \"\"\"\n",
    "        state_n = env.observation_space.n\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((state_n, self.action_n))\n",
    "        self.gamma, self.learning_rate, self.epsilon = gamma, learning_rate, epsilon\n",
    "\n",
    "    def decide(self, state):\n",
    "        \"\"\"Return the greedy action for state, or a random one when the draw is not above epsilon.\"\"\"\n",
    "        if np.random.uniform() > self.epsilon:\n",
    "            return self.q[state].argmax()\n",
    "        return np.random.randint(self.action_n)\n",
    "\n",
    "    def learn(self, state, action, reward, next_state, done, next_action):\n",
    "        \"\"\"Move q[state, action] towards the SARSA target.\n",
    "\n",
    "        The target is reward + gamma * q[next_state, next_action]; the\n",
    "        bootstrapped term is multiplied by zero when done is true.\n",
    "        \"\"\"\n",
    "        discounted_next = self.gamma * self.q[next_state, next_action]\n",
    "        target = reward + discounted_next * (1. - done)\n",
    "        self.q[state, action] += self.learning_rate * (target - self.q[state, action])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_sarsa(env, agent, train=False, render=False):\n",
    "    \"\"\"Play one episode in SARSA order and return the summed reward.\n",
    "\n",
    "    The next action is chosen before the agent learns, so that learn()\n",
    "    receives (state, action, reward, next_state, done, next_action).\n",
    "    With train=True the agent is updated after every step; with\n",
    "    render=True the environment is drawn before every step.\n",
    "    \"\"\"\n",
    "    total_reward = 0\n",
    "    state = env.reset()\n",
    "    chosen = agent.decide(state)\n",
    "    finished = False\n",
    "    while not finished:\n",
    "        if render:\n",
    "            env.render()\n",
    "        following_state, reward, finished, _ = env.step(chosen)\n",
    "        total_reward += reward\n",
    "        # Drawn even when the episode has just ended, where it is meaningless.\n",
    "        following_action = agent.decide(following_state)\n",
    "        if train:\n",
    "            agent.learn(state, chosen, reward, following_state,\n",
    "                    finished, following_action)\n",
    "        state, chosen = following_state, following_action\n",
    "    return total_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 834 / 100 = 8.34\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deZxU1Zn/8c/TK1uzg6xtN7ssitgC4q7IZiLBxITEGGMmQY04GSf+DMSZmOhgTGLiL5moGSYhGeeXSIyJA6NEhCRGsxjAiAgi2iIKsi+yN013P78/6lZT3V3VC1XdtX3fr1e9+ta599Y9p6r6qVPPuXWPuTsiIpJdcpJdARERaXsK/iIiWUjBX0QkCyn4i4hkIQV/EZEslJfsCjRHz549vaSkJNnVEBFJKy+//PJed+8VbV1aBP+SkhLWrFmT7GqIiKQVM3s31jqlfUREspCCv4hIFlLwFxHJQgr+IiJZSMFfRCQLKfiLiGQhBX8RkSyk4B+HXYcqOF5ZzfYPjnOiqhqA45XV7DpUEXOfwxUn2XvkRO39jTsO8fK7++tsc/DYSQ4crWT7B8epOFlNTY3z7r6jxLr89vHKap5YsxV3x9351ZqttfUBeHb9DvYcPnXMLXuPsuy1Hew/WsnmPUd4dv1Olqx9v0VtrzgZandTjlVWsfPgqeejusZZvOq9OvUDeHffUWpqTrXv4LGTDZ6Xpo7T2PMe9sy6Hew9coKKk9Xc/vgrbN5zJOp2lVU1/HL1e1RV17Dy9V28tHkfAPuOnODg8ZNNHif82tW3df8xqqprmtzf3Xln79Ha+zU1TvnuI2zdf6zJfcPeD+rw5MvbOF5Z3eAxm6s5bd5/tJKDx05t88GxSv731e2193ccPPV8RNbjvX3H+OObe2I+7p7DJzhcEfvY4ceprnE27jhU+17bd+QET6zeys6DFby379RztuPgcY5XVvOzP7/D1v3H2BLl+SjffZi/lO/lnb2n/uf+9NZeTjbxurl71MeLtPtQBUdPVEVdt3HHIdZs2V9b/6fXbY+6XaJYOlzPv6yszFPxR14l856pXb5wSA/+8vY+wk/nLZcOZsna99kRvBmvPrsv/2fKcK754Z84VFHFN68dw7v7jvGjP74NwD9fNYw/vrmH/l3bs/TVui96944F7D9aCcCcSwbRuV0eDz73JgDzp49gw/ZDDfYBuGHimZT27Mi9T78OQLv8HCpONh14AC4e2pMX39rLjDF9WPbaTh65fhy3P/4K35w1hrt+va7OtjkG4bg9ok8Re49UcqKqmsMVp97k//qhkTz8h/Ladpxf0o2t+4+z81AFRe3y6mxbvx4fO28Az67fSfuCXH7z91MfUgO7t+ehj4/ljifWsnV/6IOotGdHcnOM8t2ngvodk4fx0Mo36zzu+NLurHon9I/23B2XYMBVD73Q6HMS+fz99z+M56tPvVZ73Eg3XVjCT/+8BYDn77yMXYcquOEnq6iMCB7P/tPFvLnrCP/4+Cv89KbzeXnLAT5zwZl8+VevsmH7odrnCeCR68fx3ec28faeUGB59Ppx3Przvzc47rc/ejbnDOzKvU9v4MNn92Peb16L2ZbHPjeezyxaxdkDujB9dF8mDurOkrXb+dlftjBtVB+e3bATgIK8HCqrTtW7IDeHSUN68PymUwH72nP785tXQq/LuOKufP7iQXwxon6XDOvFC2/uYdLgHgzu1Ykn1mzlRFXD92Gnwjzuv3YMP/nTO7y69QMuH96LPwTHGdGniDN7dGD5hl0ALJg1miWvbGfVloYdhHs+PJJv/O/rdcouGtKTGyeV8IXHoseRnp0Kqa6p4eiJ6jqv08yx/Th4/GRte//1QyNZ8fpOtn9QwXv7j3F+STfa5efy4lt7a/e5fHgvpo/u2+D/ZHxJ96j1/eT4Yh5f9V7M+v/HDecxdVSfqPVuipm97O5lUdclK/ib2TTg+0Au
8GN3fyDWtqkW/DfuOMSIPkWUzl+W7KqISBbY8sDVp7VfY8E/KWkfM8sFHgamAyOBT5rZyGTUpaVefvcA07//Ij9+8Z1kV0VE5LQlK+c/Hih3983uXgksBmYmqS51LHjm9dpUzF1PvkrJvGdYsvZ93tx1GAjlUQFeeCt2nlIS7xNlA5NdhRb74mWDk12FhBnZt3Oj60f168yD150DhFJ6zXXrZYMZ1LMjpT07cv+sMUw+6wyuHNG7wXZDe3fiyhG9ueacflwyLOp1yppt5th+ce1fX7jd0XTvWFC7POeSQXXWNfacfujsvrXLP/zUuXHULrakpH3M7GPANHf/fHD/BmCCu8+N2GYOMAeguLj4vHffjXl9ooQK5/G3PHB1nZx+uOyZdTu47RcN863SMq9+bQrn3PtczPX3zxrDw38o52R1DU/eMoniHh2oOFnNV369jk07D/PGzsP0Lipkd8RAdjRv3z+D3Bzj60s38LO/bKktf+3rU+hQkMfgr4ZSd5vvn8HGnYe49f/9nfciBlXb5+cyql9nPnH+QK4Z24/vrXiTNVsO8PK7B2q3GV/aneoa5+V3DzC4V0dmndufLu3zueGCktr30DeuGcU9SzcA8PPPT+D6H/8NgP5d2/N/Z4/lZ3/ewjOv7WhQ/y9eNphHng91RrY8cDUHjlZy7n0ratcP7N6e788+l2sf+UvM52DjvdOoOFldZ79wm49WVvG7jbvpVVTIgWOVzP3FK0Aox/7GzlCH5zdfnMS44m6U7z7M5O+9UFuXR59/m3HFXTlw7CRTR52BmdU+drjd55d0Y/WWA/zHDedR1C6Pu55cx7YDp8ZJnvriJM4tbvhhEd7/z/Ou4N19R5k0uGftusqqGq7+wYscOVHFi3ddzitbP+C6H/0VgCG9O7Hyny+lqrqmdlzhyIkqJtz/O64Y0ZtHrh9Hu/xcjpyoYvQ9y4HQmMU914yiS/v8Bsf/6/wruOCbv+eac/rx4HXn8MvV73GooorbLh/Ce/uOcexkFSP6dG4QKwA6FOTyyteuYvi/PFv7nL237xjHT1YzvE9R1NfqxkWr+OObe9jwjal0LIz/upuNpX2SdVVPi1JW51PI3RcCCyGU82+LSkVa//7BqOU50Woutdrn53I8OKvjw+f049ZLB9O9YwETv/m7Ott16ZDP07dfxIf+/U8AzDq3PzdfOoh//305E0u786kJxXxqQnGdfdrl5/L92efyyPPlvPHsJm6cVMKRE1V065DP/cve4Oqz+3LgaCWPXD+OZ9fvZHT/LuQGL9hXZ5zFlFFn0L9re9yhqF1+ncfOyTFG9evCC3ddztt7jtAuP5ffb9zFBYN7MKT3qX/U+dPPYvfhCh5a8Rafu7CEAd06UJiXw7r3D/KRh/9Mh4I85l4xtHb78IDpjZNKGN6niLP6dKZLh3yeu+MSitrl0bdLewB6FxXywpt7OHtgF64e04/hfYp4b/9RZp07gOmj+7JxxyEAOrfP55yBXbl+QjF3PbmOz04qZVxxN37x+Qm8vecI/7ok9AHz0XED+PXftzF/+gjaF+SSnxt6Hm6+dBDvHzhO2ZndyMkxitrl85Fz+9fWd+qoPtS481YwGP0/cy+kc+1zVffNf2sj32zCg5jzZ5zF/F+/xsVDe9KhII/n77yMv72zn6fX7eDxVe/RrUNB1P2/9dExPL8pdAJE/67t66wryMthxT9fWnv//JLuLJg1mrufWs8nx4feM3m5OeTlhhIbHQvz2HjvNNrl59R+QHXIz2V0/85UVtXwtQ+PrBP4I/Xt0p5f3XIBo/t1oSAvhxsuKKldV9yjQ+3yJcN6cc05/RjZtzO9igr57fodTBrcg8K8XODUt43IfaL54afO5a3dRxIS+JuSrJ7/BcDX3X1qcH8+gLt/M9r2bTngG/4Ev2/mqNp/pLAtD1zNHb9cy1OvtOy0yFR262WDeTToWUaz6d+m1fZcYunRsYB9wdkpWx64mvXvH6Rf1/a1X3l3Hapg
wv2ngn/fLu346/wrAXhz12Fe3foB17UgrVNxspqFL2zm5ksHUZiXy4GjlVz/47/x8PXjKO3ZsdmPA/CfL2xm24FjfGPm6BbtV9+WvUe57MHn+cwFZ3JvxGPtPlzBoeMn63yAJEplVQ35uVanx11ZVYMZ5OUYVTVOXs6p9Sera+rcb6ndhyoYf//vuHZcf7738bGNbltT41S7k58bPbNccbKaDdsPcd6ZzU8RNSXa83G6ntuwk64dChhf2j3uxzpZXUOuGTlJ6DmmYs9/NTDUzEqB94HZwKeSVJeoaqJ8JlacrM6owA9w55ThPPr828y9fAg//EM5cCpwABTm5dKzUwF7j1TGfIzwm/pLV4Z6vKP7d6mzPvy/2KNjAU/ccgE9IvKgw84oYtgZLQuM7fJz+ccrT/Wuu3UsYNmXLm7RY4R9oV4e9nSV9OzI0rkXMqJP3Txu76J29C5ql5Bj1FeQ1zCwRpaFe/un7sc3xNe7czuevv0ihvTu1OS2OTlGTtQv+CHt8nMTGvgh+vNxuqac5qmV0cT7vLeWpNTK3auAucByYCPwhLtvaHyvthXtG9FF3/p9EmrSPINa2OMNy80xtjxwNXdOHV5bVn7/jDrbhD8IfvTpccwY0/CfIvxcXV8vTRMWThvcfOkgBvfqRNcYX/XT3dkDuiY0AKWi0f270C4/N9nVkARI2jvV3Ze5+zB3H+zuC5JVj1iiJcMa6/0m2+KbJ3LdeQNa5bGrq0PPxgWDevLI9ec1XB98OOTG+FrbLj+XLQ9czZxLMufsF5F0l9ndlDikwQ+f6+hd1I7vXHcO678xNeGPHe755wVphEWfrZtCbCr4i0jqUfCPIZ1i/x2Th9Uud2rBWQI/u+n8mOt+9+VLWTxnIgDTRodSPeGUxuXDe3NzRK68tFcoB5yquU0RaSjr/1sPV5zkv/+6pUGOPx2ueRSWl9t4j7sgCMp/mXdFbdn9s8Zw2fCGP6YJG9yrExMH9QDgWx89m1V3X1kb3M2M+TPOqt32p589n0WfLWuT09NEJDGy/r/1niUb+M0r7zOkdxEXDO5RW75kbeteUa8trb3nKmq87reC+ufQN6YgL6fRM1a6dyzgihFnxFVHEWlbWd/zD5+fXnGyus4leF+L8SOvthb+mXfk6ZFhkyI+rKJ5+vaLePGuy+lQkFcb+Lt2iP5jFhHJLlnf84/8Pcjrwa8oU8m86SN4et2OBqfXFbXL45yBXfnL2/sa7HPzJYOoqvEG59sDrLjj0jrX149UkJtT51okTSkqzONwjGuTi0hqy/rgn+raB0F/7MCulJV0Y8na7dxy6WD+4aJSFv05+pVFI/Px9fUqKqRXUWHUdRvundrIz3Ia+v2dl7HvaOPX1hGR1JT1aZ9wz/lQxck6M0mlih6dClly24U8eN05tb86PqtvUcwAHo/8iOuhNEevosIGv2gVkfSQ9T3/8CxFX1q8ljFR0iSp4JyBXQH48lXD2PHBca6IcslbEZGWyOrgf9NPV9W5nyqDvLGU9OzIk7dOSnY1RCQDZHXa5w+bUmtClvpTtcU7aYWISCxZGfzf2XuUF9NgJq6vzhjR6Pqrx4ROA518ls6xF5GWycq0z+UPPp/sKiTE6P5dTntiZxHJblnZ8xcRyXYK/iIiWSiu4G9m15nZBjOrMbOyeuvmm1m5mW0ys6kR5dOCsnIzmxfP8UVE5PTE2/NfD1wLvBBZaGYjCU3NOAqYBjxiZrlmlgs8DEwHRgKfDLYVEZE2FNeAr7tvBKJNmDwTWOzuJ4B3zKwcGB+sK3f3zcF+i4NtX4+nHulu1rn9U/YHZiKSmVrrbJ/+wEsR97cFZQBb65VPiPYAZjYHmANQXNz8yw+no4c+MTZqeV6OhmREpHU0GfzNbCUQbSr7u919SazdopQ50dNMUS+o4+4LgYUAZWVlqXfRnVZ2fkk3hvTulOxqiEiGajL4u/vk03jcbcDAiPsDgPDsKLHKJcKnJ56Z7CqISAZrrbzCUmC2
mRWaWSkwFFgFrAaGmlmpmRUQGhRe2kp1iOo//vh2Wx6uSZ8oGxi1fFxxtzauiYhkk7hy/mY2C/h3oBfwjJmtdfep7r7BzJ4gNJBbBdzm7tXBPnOB5UAusMjdN8TVghb65m/faMvDNem+j4yOWj6we4c2romIZJN4z/Z5CngqxroFwIIo5cuAZfEcN5M0PFFKRKT16XSSJFPsF5FkUPAXEclCCv5JFuUHciIirU7BX0QkCyn4J5n6/SKSDAr+bahbh/xkV0FEBFDwTzql/EUkGRT8k0wDviKSDAr+IiJZSMG/DWXdpUlFJGUp+LeB268YwtO3X5TsaoiI1GqtyVwkwpenDE92FURE6lDPX0QkCyn4i4hkIQV/EZEspOAvIpKF4gr+ZvYdM3vDzNaZ2VNm1jVi3XwzKzezTWY2NaJ8WlBWbmbz4jl+uvvOx85OdhVEJEvF2/NfAYx297OBN4H5AGY2ktD8vKOAacAjZpZrZrnAw8B0YCTwyWDbrHRdjPl7RURaW1zB392fc/eq4O5LwIBgeSaw2N1PuPs7QDkwPriVu/tmd68EFgfbtqrfbdxFybxnWPhCak3eLiKSLInM+X8O+G2w3B/YGrFuW1AWq7wBM5tjZmvMbM2ePXviqth3lm8C4P5lqTN5+xcvG5zsKohIFmvyR15mthLoE2XV3e6+JNjmbqAK+Hl4tyjbO9E/bKJe9cDdFwILAcrKyjLuygifnVSS7CqISBZrMvi7++TG1pvZjcCHgCvdPRyktwGRCe0BwPZgOVa5iIi0kXjP9pkGfAW4xt2PRaxaCsw2s0IzKwWGAquA1cBQMys1swJCg8JL46lDoowv7d6mx8u4rzIiklbivbbPD4FCYEVwXfqX3P0Wd99gZk8ArxNKB93m7tUAZjYXWA7kAovcfUOcdYjb3MuH8Pf3DiS7GiIibSbes32GuPtAdx8b3G6JWLfA3Qe7+3B3/21E+TJ3HxasWxDP8RNhwazR3Dl1OLk5zZ9U5eFPjWvFGomItL6s/4VveJSifvDv0bEg5j4j+3VuzSqJiLS6rAj+3kiCPbwqt950ij+7aXzMfUp7dqSrJmMXkTSWFcG/MeETlOr3/JuaWjc/t+VP3SMR6aLGPpBERFpb1gf/sPrB3x16doqd+jkdk4b0pHdRYUIfU0TkdGR98A/3wM/o3C7K2uYPAouIpJOsD/41QfSfN30ED1w7htH9Q4O53sSZ+ErbiEg6y/rgH9YuP5fZ44uxOr19RXgRyUxZEfxjDd5+akIxs88vTuhjioikg6wI/tFSNB0Lcrl/1hjaF+Q2ss+pCH9+SbfWqZyISBJkRfCP5uZLo19S+a5pw+naIZ8hvTvVKbcEDf7qG4OIpIKsDf63XzEkavnFQ3ux9mtT6FiYR2M5/6J2p3dZpOvOGxjX/iIiiZC1wd9a2gWvt/l/NfIL4MZ8ecow3rhvWvDhIiKSHFkb/Fuq/kfFwO4d6tz/0pVDm/c4ZrTLjz7OICLSVhT8G3Uq5Df1ReG2y4dQdqYGhUUkPSj4N+pUzj9RA74iIqkg3pm87jOzdWa21syeM7N+QbmZ2Q/MrDxYPy5inxvN7K3gdmO8DWgrTfX8dRaPiKSTeHv+33H3s919LPA08LWgfDqhqRuHAnOARwHMrDtwDzABGA/cY2ZpkStpMvg3YxsRkVQR70xehyLuduRUnmQm8JiHvAR0NbO+wFRghbvvd/cDwApgWjx1aF0ROf8m0j5mpuv9iEjaiPt8QzNbAHwGOAhcHhT3B7ZGbLYtKItVHu1x5xD61kBx8eldgiGR1KsXkUzSZM/fzFaa2foot5kA7n63uw8Efg7MDe8W5aG8kfKGhe4L3b3M3ct69erVvNYkXPO78kr7iEg6abLn7+6Tm/lYvwCeIZTT3wYMjFg3ANgelF9Wr/z5Zj5+SlPgF5F0Eu/ZPpG/bLoGeCNYXgp8JjjrZyJw0N13AMuBKWbWLRjonRKUpb0W/2JYRCSJ4s35P2Bm
w4Ea4F3glqB8GTADKAeOATcBuPt+M7sPWB1sd6+774+zDq2oZQFdA74iki7iCv7u/tEY5Q7cFmPdImBRPMdtO4rmIpKZ9AvfBFLmR0TShYK/iEgW0nWFG9X2XfmX/2UyNco2iUgrU/BvVNtH4R6dCtv8mCKSfZT2ERHJQgr+rWBk386cM7BrsqshIhJTVgR/b+X0zfzpI+rc//o1o1hy24WtekwRkXhkRfBvbTdfOjjZVRARaZGsCP6ahUtEpK6sCP6tnfYREUk32RH8FftFROrIiuB/+pddOL0dXZ82IpLisiP4n3bOX0FcRDJTxgf/la/vYtOuw216TF3bX0RSXcYH/x//aXOyqyAiknISEvzN7E4zczPrGdw3M/uBmZWb2TozGxex7Y1m9lZwuzERx2896sGLSGaK+8JuZjYQuAp4L6J4OjA0uE0AHgUmmFl3QnP8lhFKqL9sZkvd/UC89YhZv7gC+Onl/DXgKyKpLhFX9XwIuAtYElE2E3gsmNHrJTPramZ9CU3eviI8daOZrQCmAY8noB5t7hdfmMDW/ceSXQ0RkRaLK/ib2TXA++7+ar1Bzv7A1oj724KyWOWtpjXHXicN7gm6soOIpKEmg7+ZrQT6RFl1N/BVYEq03aKUeSPl0Y47B5gDUFxc3FQ1U4IuIyEi6aLJ4O/uk6OVm9kYoBQI9/oHAH83s/GEevQDIzYfAGwPyi+rV/58jOMuBBYClJWVpUUSXZeREJF0cdpn+7j7a+7e291L3L2EUGAf5+47gaXAZ4KzfiYCB919B7AcmGJm3cysG6FvDcvjb0ZsOuVeRKSh1prGcRkwAygHjgE3Abj7fjO7D1gdbHdvePBXRETaTsKCf9D7Dy87cFuM7RYBixJ1XBERabmM/4VvWw7CasBXRNJFxgf/tlR/wPfTE4vp37V9kmojIhJba+X8Bfi3j4xJdhVERKLK+J5/os/26dmpILEPKCKSBOr5t8Dr904lR+eOikgGUPBvgQ4FjT9dGvAVkXSR8WmftqRf+IpIusj44K9ZtUREGsr44N+WlPYRkXSh4C8ikoUyPvirLy4i0lDGB/+2pAFfEUkXCv4iIlko44N/PCf7TB/dt2XHUpJJRNJExgd/jyMTc8+HR/L92WMTVxkRkRSR+cE/jn3zcnPo1iF0LZ94PkRERFJNXMHfzL5uZu+b2drgNiNi3XwzKzezTWY2NaJ8WlBWbmbz4jm+iIicnkRc2+chd38wssDMRgKzgVFAP2ClmQ0LVj8MXEVozt/VZrbU3V9PQD2i8gR12fVDYRHJJK11YbeZwGJ3PwG8Y2blwPhgXbm7bwYws8XBtq0W/BNFaR8RySSJyPnPNbN1ZrbIzLoFZf2BrRHbbAvKYpU3YGZzzGyNma3Zs2dPAqopIiJhTQZ/M1tpZuuj3GYCjwKDgbHADuC74d2iPJQ3Ut6w0H2hu5e5e1mvXr2a1ZjWpLSPiGSSJtM+7j65OQ9kZv8JPB3c3QYMjFg9ANgeLMcqT2lK+4hIJon3bJ/IX0HNAtYHy0uB2WZWaGalwFBgFbAaGGpmpWZWQGhQeGk8dWiKgraISEPxDvh+28zGEkrdbAFuBnD3DWb2BKGB3CrgNnevBjCzucByIBdY5O4b4qxDoxJ1vR2lfUQkk8QV/N39hkbWLQAWRClfBiyL57ipTl82RCTVZf4vfBWJRUQaUPAXEclCGR/8RUSkoYwP/okaqG3JNwiNDYtIqsv44J+MtI8yTSKS6jI/+OtUTxGRBjI++CeKBo5FJJNkfPCPFrRX392sK1a0nL4diEiayPzgH6WsV1Fhix+nWWkffTsQkTSR8cE/UZT2EZFMouAvIpKFMj/4J6jHrrN9RCSTZH7wT5BmpX30ASEiaSLjg3+85/m3qMevcQERSROZH/zjDMga6BWRTBR38Dez281sk5ltMLNvR5TPN7PyYN3UiPJpQVm5mc2L9/gpRWkfEUkTcU3mYmaXAzOBs939hJn1DspHEpqi
cRTQD1hpZsOC3R4GriI0z+9qM1vq7q/HU4/GxNtx10CviGSieKdxvBV4wN1PALj77qB8JrA4KH/HzMqB8cG6cnffDGBmi4NtWy34x0tpHxHJRPGmfYYBF5vZ38zsj2Z2flDeH9gasd22oCxWeWbQB4WIpIkme/5mthLoE2XV3cH+3YCJwPnAE2Y2iOjZbyf6h03UkGlmc4A5AMXFxU1VMyaPs+uutI+IZKImg7+7x7wKmpndCvzGQxF2lZnVAD0J9egHRmw6ANgeLMcqr3/chcBCgLKysqT1qVv02aEPChFJE/Gmff4HuAIgGNAtAPYCS4HZZlZoZqXAUGAVsBoYamalZlZAaFB4aZx1EBGRFop3wHcRsMjM1gOVwI3Bt4ANZvYEoYHcKuA2d68GMLO5wHIgF1jk7hvirEOjdLaPiEhDcQV/d68EPh1j3QJgQZTyZcCyeI7bEjpbR0SkoYz/hW+8WvLh8ZVpwxnYvT1j+ndpvQqJiCRAvGmflNeWHf/zzuzOi3dd0YZHFBE5PZnf89epniIiDWR+8I8zemvMQEQyUeYHf0VvEZEGMj7461RPEZGGMj74x0tfHEQkE2V88FfwFhFpKPODf1tO4ygikiYyPviLiEhDCv4iIllIwV9EJAtlfPDXgK+ISEMZH/x1kTURkYYyPvgX9+iQ7CqIiKScjA/+IiLSUFzB38x+aWZrg9sWM1sbsW6+mZWb2SYzmxpRPi0oKzezefEcvzmU8xcRaSjembw+EV42s+8CB4PlkYTm5x0F9ANWBnP8AjwMXEVokvfVZrbU3V+Ppx4iItIyCZnMxcwM+DjBZO7ATGCxu58A3jGzcmB8sK7c3TcH+y0OtlXwFxFpQ4nK+V8M7HL3t4L7/YGtEeu3BWWxyhswszlmtsbM1uzZs+e0KnXgaCXfWb7ptPYNyw2u75Cfq+s8iEjmaLLnb2YrgT5RVt3t7kuC5U8Cj0fuFmV7J/qHTdSsvLsvBBYClJWVJS1zP3FQD269bDA3XViSrCqIiCRck8Hf3Sc3tt7M8oBrgfMiircBAyPuDwC2B8uxyhMuJyf+3npOjvGVaSMSUBsRkdSRiLTPZOANd98WUbYUmG1mhWZWCgwFVgGrgaFmVmpmBYQGhZcmoD4Jj0MAAAbtSURBVA5R5SYg+IuIZKJEDPjOpm7KB3ffYGZPEBrIrQJuc/dqADObCywHcoFF7r4hAXWIKk/BX0QkqriDv7t/Nkb5AmBBlPJlwLJ4j9scOboYv4hIVBn9C1+lfUREosvo4K/YLyISXUYHf1PaR0QkqowO/iIiEp2Cv4hIFlLwFxHJQlkT/HsVFSa7CiIiKSNrgv/quxu9SoWISFbJmuAvIiKnKPiLiGQhBX8RkSyUkJm8UtmWB65OdhVERFKOev4iIllIwV9EJAsp+IuIZKG4gr+ZjTWzl8xsbTDZ+vig3MzsB2ZWbmbrzGxcxD43mtlbwe3GeBsgIiItF++A77eBb7j7b81sRnD/MmA6oakbhwITgEeBCWbWHbgHKCM0cfvLZrbU3Q/EWQ8REWmBeIO/A52D5S6cmox9JvCYuzvwkpl1NbO+hD4YVrj7fgAzWwFMo940kK3lJzeWcbLa2+JQIiIpLd7g/0/AcjN7kFAKaVJQ3h/YGrHdtqAsVnkDZjYHmANQXFwcZzVDrjzrjIQ8johIumsy+JvZSqBPlFV3A1cCd7j7r83s48BPgMlAtFlUvJHyhoXuC4GFAGVlZequi4gkUJPB391jXhHNzB4DvhTc/RXw42B5GzAwYtMBhFJC2wilfiLLn292bUVEJCHiPdVzO3BpsHwF8FawvBT4THDWz0TgoLvvAJYDU8ysm5l1A6YEZSIi0obizfl/Afi+meUBFQQ5emAZMAMoB44BNwG4+34zuw9YHWx3b3jwV0RE2k5cwd/d/wScF6Xcgdti7LMIWBTPcUVEJD76ha+ISBZS8BcRyUIK/iIiWchC6fnUZmZ7gHfjeIiewN4EVSeZ
MqUdoLakqkxpS6a0A+Jry5nu3ivairQI/vEyszXuXpbsesQrU9oBakuqypS2ZEo7oPXaorSPiEgWUvAXEclC2RL8Fya7AgmSKe0AtSVVZUpbMqUd0EptyYqcv4iI1JUtPX8REYmg4C8ikoUyOvib2TQz2xTMJTwv2fVpDjPbYmavhedFDsq6m9mKYN7jFcEVURudKzlJdV9kZrvNbH1EWYvrnux5nmO04+tm9n7wuqwNpi0Nr5sftGOTmU2NKE/6+8/MBprZH8xso5ltMLMvBeXp+LrEaktavTZm1s7MVpnZq0E7vhGUl5rZ34Ln95dmVhCUFwb3y4P1JU21r1ncPSNvQC7wNjAIKABeBUYmu17NqPcWoGe9sm8D84LlecC3guUZwG8JTZIzEfhbkut+CTAOWH+6dQe6A5uDv92C5W4p0I6vA3dG2XZk8N4qBEqD91xuqrz/gL7AuGC5CHgzqHM6vi6x2pJWr03w3HYKlvOBvwXP9RPA7KD8R8CtwfIXgR8Fy7OBXzbWvubWI5N7/uOBcnff7O6VwGJCcwuno5nAfwXL/wV8JKL8MQ95CQjPlZwU7v4CUP8S3S2t+1SCeZ7d/QAQnue5zcRoRywzgcXufsLd3yF0GfPxpMj7z913uPvfg+XDwEZCU6em4+sSqy2xpORrEzy3R4K7+cHNCc2J8mRQXv81Cb9WTwJXmpkRu33NksnBv9nzBacYB54zs5ctNI8xwBkemgyH4G/voDwd2tjSuqdym+YGqZBF4TQJadSOIF1wLqGeZlq/LvXaAmn22phZrpmtBXYT+iB9G/jA3aui1Km2vsH6g0AP4mxHJgf/Zs8XnGIudPdxwHTgNjO7pJFt07WNkIB5ntvYo8BgYCywA/huUJ4W7TCzTsCvgX9y90ONbRqlLKXaE6UtaffauHu1u48lNJXteOCsRurUKu3I5OAfax7hlObu24O/u4GnCL0xdoXTOcHf3cHm6dDGltY9Jdvk7ruCf9ga4D859fU65dthZvmEguXP3f03QXFavi7R2pLOr427f0BoHvOJhFJs4Qm2IutUW99gfRdCacm42pHJwX81MDQYQS8gNFCyNMl1apSZdTSzovAyoTmO1xOqd/jsihuBJcFyrLmSU0lL656S8zzXG0uZReh1gVA7ZgdnZJQCQ4FVpMj7L8gN/wTY6O7fi1iVdq9LrLak22tjZr3MrGuw3B6YTGj84g/Ax4LN6r8m4dfqY8DvPTTiG6t9zdNWI9zJuBE6c+FNQvm0u5Ndn2bUdxCh0ftXgQ3hOhPK7/0OeCv4291PnTXwcNC+14CyJNf/cUJfu08S6pX8w+nUHfgcocGrcuCmFGnHfwf1XBf80/WN2P7uoB2bgOmp9P4DLiKUClgHrA1uM9L0dYnVlrR6bYCzgVeC+q4HvhaUDyIUvMuBXwGFQXm74H55sH5QU+1rzk2XdxARyUKZnPYREZEYFPxFRLKQgr+ISBZS8BcRyUIK/iIiWUjBX0QkCyn4i4hkof8PE3z38TlHrqkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = SARSAAgent(env)\n",
    "\n",
    "# Training: play episodes with learning enabled and record each episode's total reward.\n",
    "episodes = 3000\n",
    "episode_rewards = []\n",
    "for episode in range(episodes):\n",
    "    episode_reward = play_sarsa(env, agent, train=True)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    \n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "# Evaluation: play 100 episodes without learning and report the mean reward.\n",
    "agent.epsilon = 0. # disable exploration\n",
    "\n",
    "# NOTE: this rebinds episode_rewards, so the training rewards survive only in the plot above.\n",
    "episode_rewards = [play_sarsa(env, agent) for _ in range(100)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "显示最优价值估计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>-3.906945</td>\n",
       "      <td>-3.926261</td>\n",
       "      <td>-3.354537</td>\n",
       "      <td>-3.844973</td>\n",
       "      <td>0.169093</td>\n",
       "      <td>-7.107725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>-2.032215</td>\n",
       "      <td>-0.451203</td>\n",
       "      <td>-2.069749</td>\n",
       "      <td>-2.175503</td>\n",
       "      <td>7.713113</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>-3.152140</td>\n",
       "      <td>-3.810306</td>\n",
       "      <td>-3.137979</td>\n",
       "      <td>-3.851483</td>\n",
       "      <td>2.509615</td>\n",
       "      <td>-7.105844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>-5.791234</td>\n",
       "      <td>-5.933590</td>\n",
       "      <td>-5.896414</td>\n",
       "      <td>-5.953635</td>\n",
       "      <td>-7.072424</td>\n",
       "      <td>-7.107966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>495</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>496</td>\n",
       "      <td>-2.843914</td>\n",
       "      <td>-2.876611</td>\n",
       "      <td>-2.881306</td>\n",
       "      <td>-2.877551</td>\n",
       "      <td>-3.600000</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>497</td>\n",
       "      <td>-1.680592</td>\n",
       "      <td>-1.225120</td>\n",
       "      <td>-1.385155</td>\n",
       "      <td>-1.430747</td>\n",
       "      <td>-3.600000</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>498</td>\n",
       "      <td>-2.917878</td>\n",
       "      <td>-2.961647</td>\n",
       "      <td>-3.112380</td>\n",
       "      <td>-2.918050</td>\n",
       "      <td>-5.008434</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>499</td>\n",
       "      <td>-0.707040</td>\n",
       "      <td>-0.424800</td>\n",
       "      <td>-0.707040</td>\n",
       "      <td>8.373279</td>\n",
       "      <td>-3.600000</td>\n",
       "      <td>-3.636000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            0         1         2         3         4         5\n",
       "0    0.000000  0.000000  0.000000  0.000000  0.000000  0.000000\n",
       "1   -3.906945 -3.926261 -3.354537 -3.844973  0.169093 -7.107725\n",
       "2   -2.032215 -0.451203 -2.069749 -2.175503  7.713113 -3.636000\n",
       "3   -3.152140 -3.810306 -3.137979 -3.851483  2.509615 -7.105844\n",
       "4   -5.791234 -5.933590 -5.896414 -5.953635 -7.072424 -7.107966\n",
       "..        ...       ...       ...       ...       ...       ...\n",
       "495  0.000000  0.000000  0.000000  0.000000  0.000000  0.000000\n",
       "496 -2.843914 -2.876611 -2.881306 -2.877551 -3.600000 -3.636000\n",
       "497 -1.680592 -1.225120 -1.385155 -1.430747 -3.600000 -3.636000\n",
       "498 -2.917878 -2.961647 -3.112380 -2.918050 -5.008434 -3.636000\n",
       "499 -0.707040 -0.424800 -0.707040  8.373279 -3.600000 -3.636000\n",
       "\n",
       "[500 rows x 6 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the learned Q-table as a DataFrame: one row per state, one column per action.\n",
    "pd.DataFrame(agent.q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "显示最优策略估计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>495</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>496</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>497</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>498</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>499</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       0    1    2    3    4    5\n",
       "0    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "1    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "2    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "3    0.0  0.0  0.0  0.0  1.0  0.0\n",
       "4    1.0  0.0  0.0  0.0  0.0  0.0\n",
       "..   ...  ...  ...  ...  ...  ...\n",
       "495  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "496  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "497  0.0  1.0  0.0  0.0  0.0  0.0\n",
       "498  1.0  0.0  0.0  0.0  0.0  0.0\n",
       "499  0.0  0.0  0.0  1.0  0.0  0.0\n",
       "\n",
       "[500 rows x 6 columns]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encode each state's greedy action (argmax over its Q-table row).\n",
    "policy = np.eye(agent.action_n)[agent.q.argmax(axis=-1)] \n",
    "pd.DataFrame(policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 期望 SARSA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExpectedSARSAAgent:\n",
    "    \"\"\"Tabular Expected SARSA learner that acts epsilon-greedily on its Q-table.\n",
    "\n",
    "    The table q has one row per state and one column per action and\n",
    "    starts out filled with zeros.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        \"\"\"Store the hyper-parameters and allocate the Q-table.\n",
    "\n",
    "        env: environment exposing observation_space.n and action_space.n.\n",
    "        gamma: discount factor applied to the expected next-state value.\n",
    "        learning_rate: step size of every Q-table update.\n",
    "        epsilon: threshold below which a uniformly random action is taken.\n",
    "        \"\"\"\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, self.action_n))\n",
    "        self.gamma, self.learning_rate, self.epsilon = gamma, learning_rate, epsilon\n",
    "\n",
    "    def decide(self, state):\n",
    "        \"\"\"Return the greedy action for state, or a random one when the draw is not above epsilon.\"\"\"\n",
    "        if np.random.uniform() > self.epsilon:\n",
    "            return self.q[state].argmax()\n",
    "        return np.random.randint(self.action_n)\n",
    "\n",
    "    def learn(self, state, action, reward, next_state, done):\n",
    "        \"\"\"Move q[state, action] towards the Expected SARSA target.\n",
    "\n",
    "        The next-state value is epsilon times the mean of q[next_state]\n",
    "        plus (1 - epsilon) times its maximum; it is multiplied by zero\n",
    "        when done is true.\n",
    "        \"\"\"\n",
    "        next_q = self.q[next_state]\n",
    "        random_part = next_q.mean() * self.epsilon\n",
    "        greedy_part = next_q.max() * (1. - self.epsilon)\n",
    "        expected_value = random_part + greedy_part\n",
    "        target = reward + self.gamma * expected_value * (1. - done)\n",
    "        self.q[state, action] += self.learning_rate * (target - self.q[state, action])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_qlearning(env, agent, train=False, render=False):\n",
    "    \"\"\"Play one episode, choosing an action at each step, and return the summed reward.\n",
    "\n",
    "    Unlike play_sarsa, no next action is passed to the agent: learn()\n",
    "    receives (state, action, reward, next_state, done). With train=True\n",
    "    the agent is updated after every step; with render=True the\n",
    "    environment is drawn before every step.\n",
    "    \"\"\"\n",
    "    total_reward = 0\n",
    "    state = env.reset()\n",
    "    finished = False\n",
    "    while not finished:\n",
    "        if render:\n",
    "            env.render()\n",
    "        chosen = agent.decide(state)\n",
    "        following_state, reward, finished, _ = env.step(chosen)\n",
    "        total_reward += reward\n",
    "        if train:\n",
    "            agent.learn(state, chosen, reward, following_state,\n",
    "                    finished)\n",
    "        state = following_state\n",
    "    return total_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 817 / 100 = 8.17\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXxU9b3/8dcnO2sIJKxJWMO+GwOICMoWQMSitmDdvXWptnr7u9frdtW611Zb+7t2wZZbvV3UWq3U4gKtVW/rAqgoyGIELDsoyiqBJN/7x5wMk2SyTiYzmfN+Ph7zyMz3nDnn+52c+cz3fL/f8z3mnENERPwlKdYZEBGRlqfgLyLiQwr+IiI+pOAvIuJDCv4iIj6UEusMNER2drbr06dPrLMhItKqrFq16lPnXE64Za0i+Pfp04eVK1fGOhsiIq2KmX1S2zI1+4iI+JCCv4iIDyn4i4j4kIK/iIgPKfiLiPiQgr+IiA8p+IuI+JCCfyM9//4OPvnscJW0svIKDh49zv4jx2usf6ysgiPHyth/5DiNmT77cGkZpWXlrN91gFc27OGDbfv5/kvrWb/rAKs+2ce6nQeC65ZXOBb/72ZKy8oBOHj0OMfLKxq0ny2fHubo8XIOHj3O1n1HOFRaxt6DpVXWcc7x6GubKC0rxznH1n1H+OxQYJ0vjhzjh8s28uSKfwbXP3j0OH98dzsVFY6nVm5l1Sf7eG/rFzX2ffR4OV8eC+R5/5FAnvd/eZxtnx+pkodw5TlyrIxn3tnGgaPHKa+o+bk65/j88DGOe/8bgHU7A59dbUrLyoP/q4ow26zNl8fKOXq8PPj6bxv2sGb7/irl/Mu63by56TOOHCtj+xdf8tf1uwPl/rJq/qu/rmt/Xxw5Vus6B44e57n3tnPg6HHKvM+uvMKx/8vAZ1F5fFU6XFrGsbIKPjtUyoGjJ47jNdv38/gbW1iyekfwGH561bYq5S2vcDy1Yiv7Dh+r8blVz+PqrV/wyvo9HDh6nC+OHAv7nQnn5bW72HPgKK9u3MvWfUcAOOTl+elV21j02sc13uOcY/+R4xwqLePpVduC3783Pv6sxnGw+8BRXlm/h5I9BwHC5mvrviPBzzLUodKyGsdnaVk5nx8+Fjy+q6t+3B4uLWP5h7u5/bk1VT7baLLWMJ9/YWGhi5eLvPrc+GcAttw/h82fHmbzp4d47r0dPPfeDgCe+eYpJJsx75G/06ltKt07ZrB+18Hg+4v6dubtzfu4f/4Iyioct/5xDQCb7p3N8x/s5Nu/e5c75g7ljj99yMBu7dm4+1CtefnxwjHc9twavgg5UCcVZPP6R58yMjeTH5w3ihk/fI1endqw/YsvSUtO4qTeWbyx6TPOOymX36/aVuu2Jw/M4dY5Q0hOMr734npeWhsIVnNH9eRPq3eEfc/pg3LISE3m9Y8+5VBpGR0yUjh4tCy4fPElhVz2q5Xkd27Ln649lVF3vgzAxrtnMfDWF8Jus/IzKOydxR1nDSOvc1uOlVVw8j3Lg+uMzM3k/W37g2Wqr2zfnlrAn9/fQY/MNvzwa6M5cqyMvKy29Lt5adj1f/eN8Sx89E2um1rAw3/5KJj+wLkjOX1Q12Be7pg7lL457bl48dsAtE1L5o65w7jhD++H3e53pg/koWUbGdKjI6cVZPPz1zYFl/3gvFHc/OwHHCs7EVRSkow2qckcLD3xmbZLSyarXRqHSsuYPaIHz6/ewYGQzzzU4O4dWL/rIJef2pdf/u9mstun8/CC0fTIzOCMB1+t9fMKldU2lc+rBcavFebx5MqtVdL+cPUpPPraJl5cuwsIHN87Dxxl4v1/rbHNP14zkb+u38PqrV/w4c4D3DBzEKcMyA6u+/hlRVy0+O3gcQwnjvO6tE9P4VBpGX2z27H508MkJxnfOmMAP1oe+B+uvn0GG3Yd5Ks/f6PK+66a3J+fvfoxF4zPZ/eB
Ut755HM+O3ziR2xhUR7fnlrAAy9uoLSsnKUf7AouO39cPtnt0vjxX0uCaaNyMxnWK5PfvvVPHrusiI27DnLP0nUAXD2lPx0zUvnei+tr5P8nXx/Lmu37yWyTypWT+9dZ1tqY2SrnXGHYZQr+9bv0v9+mb3Z7bps7NBj8u3ZIZ0+1GjLAXWcP5z+9gC4i0hy23D+nSe+rK/jHrNnHzIrNbIOZlZjZjbHKR12Ol1ewdd8RXtmwl8V/31xlWbjALyLSWsQk+JtZMvAIMAsYCiw0s6GxyEt1j/1jC//8LNCmeNtza5n0wCvBZRtCmm9q8+Wx8KfcIiLxJFY1/yKgxDm3yTl3DHgCmBejvAQdKi3j9iVrWbAo0Ab4+kd7qyyf+aPX6t3GvUtrtt2JtDaj8joxKq9TTPY9Jr/l93vZxL7cOmcIpw/KafD+8zu3bdQ+vj4un5P7ZDV4/VMHZHP26J58cMeMRu2noWI1q2cvILSHaBswLnQFM7sCuAIgPz+/RTJ1158+BKjSSSmxZwaRdk21TUvmSC0jLxoiXKfmOWNz+cM7tXcst6T3bpvO6DuXVUkb0qNjlVFhAL06teH+c0aQkpTEwkffBODXl4/jgl++VWW9O88aViX4Hyot40hpGVnt0nj9o71c9quV9MjMYOf+o0Cg8/uGp2t2bK+8dRqFdwc6xCvbrXcfOMqqTz7nm795B4Cnr5rAuT97o8Z7AZIMQgcQhRscsOne2Xx+5Bhd2qcH0yr75gBunTOEKYNymPbQa8F8/OPjTzn/0UCZJ/Trwm1zAw0P/zKpHxDo53tlQ9XKX6XB3Tvw4vWnAfDx3kNM9TrL7z57OMXDu/Poa5v4+WubGNazI3/+9qRgXu75ygg+O1TKSd7nsf6uYjJSk/nxXz7ioWUbmTe6J8+9t4Prphbwr9MHht13c4pVzd/CpFX5ejvnFjnnCp1zhTk5YaejbnbVv9ytoC884V17+gA239fwzq7x/TpXeX1ynywmFWTzzn9O5xuT+oZ9z/fOGUFGahKZbVJZf1cxG+4upku7NADunDeMNd+dyb3zR/DgeaOC73l4wWge/Ooo/v/CMay+fQanD6p6jN4yewipyScO81P6d6k3749dVsSs4d35w9WnBNNeuG4S988fEQzEv/mXcUwb0pX5Y3pVeW+ntmk1tlfYu2otc+awbvzX+WOYVJDDhP5deHjBaC4/tS/9ctoBgcEKldqkJVd5b/v0FLp2zCA1OYksb1+nD+7K01dNoHhYd84ZmwsERuFU+uaU/mS3T+dnF5zEwqK8YHq3jhnMHtEj+Dq/S9UadNuQfRf26cx980cEP5+0lEDIym4fyMOgbh1ISrIqgb+6qUO6MaBrB75WmMfiSwJ9n6f0z2b17TPYfN9sfnfF+BrvefSiQtZ+d2bw9aBuHQBYd2cxS649NZjeP6c9HTMCdegkM7Lbp3NeYaCsfbMDn+v10woY0qMjAO3ST9S305IDZbn29AF8cMeM4OfaqW1qrWVpTrGq+W8D8kJe5wLhxw+2gD0HjrLrwNEqaeUVLjisTJrPA+eMrHXo4/q7itny2WGKf/R6MK1DRuMO0ZtnD6FfTnsyUpJYvm43UwZ1JSM1EExumTOUrHZpPPDiBvrntOPjvYHrNTq3S+eDOwJf9FTvCzlrRHd+/eY/MQKBD+Cck3IZ168zn3x2hIkDAkFu7qieAPzy4pM5XlHBoFtfBGD60G5847R+wVrfb78xnl+8von26Sn079qezXsPV/kcvlaYx+SBOUweGPgReeT8sZw2MJsOGakM6dGRBUX5HD1eTkZqcnDfz7y7Pexn8KtLT2ZUbifapCUzcUAXrvp1oIb98wurDvqYN7oX80b3Cn726SlJzB/Ti7+XfMpAL9iFMyY/i59dMJbJA7vSJi2Zwj6BH9yNd88iOcno7w2ZvaF4MADFw7tTPLx7je08fdUEumdm0LVDBuvvKuaZd7Zz87MfkJfVll9c1Jcvj5dzWkEOmW1T
WVh04uz/t98YR/+c9mS1TSMpXDUyxFmjegabZ7537sgqyzLb1B5kU5KTSElO4sLxvZkzsgeFvbOocAR/fELNGt6DJ1duxby8DOjankUXnhT8P10/bSDXTwvU5DNSk1l/VzEASV7mk5KMDhmpnD64K7/6xxZO7tO5xj6iIVbBfwVQYGZ9ge3AAuD8GOWFqQ++WmXsNAbvb6t5UZJELrdzGyDQlDNzaPfgOHCA9JQkBnfvyGOXFfHy2l385q0TF46lJSdxLMwFNj0zM9jhNT08/61TGd4rM7iseHiPGutXmj60O7k7D/Dqxr0YJ4J+peBZn1WNLrlZbcnNqtnWm5RkpCcls+xfT+OxN7YEA87rN5wevLCqskkBCF7IM75fZy4/tR/ThnStsr05I2vmvfJHrFLoD1ioKYNObKuuzyDcttulpzBjWM1AXV247YYLjHUpDAly1cs2bWi3Wt93Sv/sWpdV2nTvbMzArJ5fh3qEng3VZtaI7jy5cmuVM626PsPqZa00eWAOH987m+T6ftGaSUyCv3OuzMyuBV4CkoHFzrm1scgLUDXwS1RcObkfpccrGJsf+IIkm1WJq9+eWhD8ok4emMO2zwMjrioDrfNaBRcW5fHe1v2s23mAW+cM4elV24LBv3tmRqPyVPklCxcfKtuZG/s9LOjWgbvPHhF8nde5bZVT3EqdvWal4T0zmV5HoKvL0J6ZVYL/i9dPClub/ceNZ1S5ajfeRRivgRO16pYwZVDXJo/Dr66lAj/E8DaOzrmlQPhLKqOksrZV2y+vNN5pA3N4bWP4jrFQvTq14aIJfYJXrIa269519nAuHN+7yvrnF+XTL7t9sA2/siZ+25nDKKuo4AcvbeCC8b3p3C6N7zy1mscvKyK7jnbfcOq6wLFymYXtnorcsJ6Z/P6qCYyOYETN3WcP50+rdwT7GwZ37xh2vZ6d2tCTNk3eT2O9ffPUWq8yrotDnWwtqVXcw7e5jL7zZY4er2jQr3Skp4uJblC3Djxw7kiG9ezIgFvCT80QqjLOpqUkceOswUwb0pUHX94IQOcwnZVmxoSQTtLQVpgOGal8d17gdHz+2Fzme52NjeFCQk1SmP91ZX6jeRhE2rab2SaVF66bRO8ujRtyGG1dO2bQNfzvUAPpu9cSfDWx29HjDZvsTGqqPj75+mkFjMrrREpyww6h0Fr2VZP7M6BrBy6d2Dfstut6f6SnxaE1+RPt+mH25/00tOBZeJMM6dGRtmmJUYfT6LqW5avg31BfHitn1/6j9a/oI6FtyT/5+lhmhQzVu3JyvxrrP3jeKL4TMlY53Pe6qG9nttw/h64d62+rr6uW3lR1xZr26YHytkmQwNoa1NLHLlGi4B9GWYXjql+vinU24krlsDWgyhhtgK4dagbvmcO7c8nEPsHXkdbq5nj7bM6a+Il2/Zr+feYgbp49OLhfaTmK/S1DwV8a5JJT+gAnLlwJFXpWcJI33C0tOYmOGSfSIz2jf+iro1l567So9MWE22abtGSuOK1/i46+8D21+7QondMKEBhe+Wodo3bMjDXfnUlKmGA4f0wvjpdXcM7YXJIsMBVA5ZjvSyf24b//vqVRN7IJJy0lqdGjecI556ReLFm9g4sm9OE/vOkIFN7jg5p9WpZq/j626MKTgs+H9ax/eEb79JSww2STkoyFRfmkpQSuigydaiBaQyWbqmuHDF64bhK9OrUJduoq2MSXeDtmEpWCv4+FXoVY10RSRREMSeyVFRhfntMh8lp7tCjYiB/5stmnvML5si132pBuLF+3O+yy6tMbVDprVE8eXjC6yfu85JQ+5GW1afJVrNHUEmP5peH0/2hZvqz59795aYvdJDme/OLiwhoXuL1501ReuG5S2PXnje7JbXOHRtTJmpxkzBjWPS4vmquoY7SPtLy6Rl9J8/Nl8AcY/J8vsvuAxvJ3z8wITjdb3X3zRzRLJ2u8GuFNApcdx01SItHiy2afSvN/8g9e+bcpsc5GxFKS
jLKKho+m+fXl4/jscM17EN8xdyj/3Pdl8H7FzXlBVTy6oXgwc0f1rHP6Ymk5J0b7JPZxFy98W/MH2P7Fl61+zv7Km1M0xqkF2cF53ENdMrFvg0b9JIrU5CRG5sbmVoVSk4b5tyxfB/9EcMbgbs3aQRZ6B6dEr/lLfBnUPXAGNiovs541pTko+CeAB78aGI0TOm6/qYpDhn8q9ktLmjggm1f/fQpfGdP4WVql8Xzd5p8ozhrVk7O82wme3CeLFVs+b/K2Qm+CoZq/tLTeXWpOHyLR4Zuaf3ktHaJvb/6shXMSXbedOSxsjb0pc74r9IskLt8E/x8t3xg2/amV21o4Jw131eT+jX7PiNxM/vjNiTXSf3/lhEZvSxV/kcTlm+C/bufBWGeh0ZrrtnbThnRr0Jz51WnInUji8k3wr02ks01GVRxnTURaNwX/WGegDrHK29VTGt/cJCKti++D/7v//CLWWahVrM5K/qN4cINuci8irZfvg388a8SMDSIijRJR8Dez88xsrZlVmFlhtWU3mVmJmW0ws5kh6cVeWomZ3RjJ/hNdPHdHiEjrFmnNfw0wH3gtNNHMhgILgGFAMfATM0s2s2TgEWAWMBRY6K0rYTTXaJ8rJ/drlu2ISOKI6Apf59w6CDskcB7whHOuFNhsZiVAkbesxDm3yXvfE966H0aSD6nbyRHciUtEElO02vx7AVtDXm/z0mpLr8HMrjCzlWa2cu/e2m8snshqa/a5aELvls2IiCSceoO/mS03szVhHvPqeluYNFdHes1E5xY55wqdc4U5OTn1ZdNXpg1p+C0R7/3KiCjmRERaq3qbfZxz05qw3W1AXsjrXGCH97y2dKmmrqGe/2/6QDJSk2tdDjAyN5Pzx+U3d7ZEJAFEa1bPJcBvzewhoCdQALxNoOZfYGZ9ge0EOoXPj1IeWr3ahnqmJBnfmlrQspkRkYQS6VDPr5jZNmAC8GczewnAObcWeIpAR+6LwDXOuXLnXBlwLfASsA54yltXwvjX6QPDpicn1T3nzvBemSwsyuPhBWOikS0RSQCRjvZ5Fni2lmX3APeESV8KLI1kv37RuV0aPTMz6NI+nQ+27w+mpyTXHfyTk4z75o+MdvZEpBXTFb5x7h83TeXJK8dXSeuf0z5GuRGRRKE7ebUiSQab7tOcOyISOdX8RUR8SMG/FUlJ0r9LRJqHL6JJyZ6DLF+3O9bZaLLK+w+n1tPRKyLSUL4I/s++uz3WWYhIanLg33T2mLAzYYiINJovOnzLylv33MgZqcm8d9t02qf74t8lIi3AF9GkLAHuitKpbVqssyAiCcQXzT7lCRD8RUSaky+Cf6zuhdsYxcO6c0r/LrHOhoj4hD+Cf6wz0ABfPTmXpJo3xRERiQpfBH8REalKwT+ONNc9e0VE6uOL4B/PTf6nD9JdykSk5fki+IuISFUK/iIiPqTgLyLiQ74I/upIFRGpyhfBX0REqvJF8I+30T5ZbVNjnQUR8TlfBP/WIt5+pEQkcfki+CumiohU5YspnePRY5cVsfvAUV74YGessyIiPqTgHyOTBwau7A0X/M8fl8/0Id1aOksi4iMRNfuY2ffNbL2ZvW9mz5pZp5BlN5lZiZltMLOZIenFXlqJmd0Yyf5bK6tn9s45I3pw+uCuLZQbEfGjSNv8lwHDnXMjgY3ATQBmNhRYAAwDioGfmFmymSUDjwCzgKHAQm9dX7l9ru+KLCJxJqJmH+fcyyEv3wTO9Z7PA55wzpUCm82sBCjylpU45zYBmNkT3rofRpKPuhTds5w9B0ujtflGe+6aiYzK6xR22XemD2Tj7ncYkZvZwrkSEb9pztE+lwEveM97AVtDlm3z0mpLr8HMrjCzlWa2cu/evU3OVDwF/vvmj2BkHYG9sE9nVt46jY4Zug5ARKKr3uBvZsvNbE2Yx7yQdW4ByoDfVCaF2ZSrI71monOLnHOFzrnCnJzWN+1xj8yMGmkLi/JrtPdPGxro2O2b3b5F8iUi
Ag1o9nHOTatruZldDJwJTHUnbpa7DcgLWS0X2OE9ry09ofzP5UVMe+i1etc7vyifuaN6qrYvIi0q0tE+xcB/AGc5546ELFoCLDCzdDPrCxQAbwMrgAIz62tmaQQ6hZdEkod41dCrdc1MgV9EWlyk4/z/C0gHlnnNGW86565yzq01s6cIdOSWAdc458oBzOxa4CUgGVjsnFsbYR5ERKSRIh3tM6COZfcA94RJXwosjWS/rUE9Q/lFRGLKF3P7xIImaROReKbgH2XZ7dNjnQURkRoU/KOsk+buF5E4pOAvIuJDCv4iIj6k4B8l6u8VkXim4B9lGvEpIvFIwT/KdAYgIvFIwV9ExIcU/KNMzT4iEo8U/EVEfEjBP0o0vYOIxLNIZ/WUBnjmm6ewYdfBWGdDRCRIwT9KQmf1HJufxdj8rNhlRkSkGjX7ROiSU/pw4fjeNdLV7CMi8Uw1/wgM6dGRO84aBsD/vPlJ2HU0r7+IxCPV/KNMZwAiEo8U/Otw0YSazTkiIolAwb8Od84bXufyhrToqNlHROKRgr+IiA8p+EeJ05RuIhLHFPxFRHxIwb8Wf7h6Qr3rqD1fRFqriIK/md1lZu+b2Xtm9rKZ9fTSzcx+bGYl3vKxIe+52Mw+8h4XR1qA6GmeyG6a11NE4lCkNf/vO+dGOudGA88Dt3nps4AC73EF8FMAM+sM3A6MA4qA282s1c57oJq/iLRWEQV/59yBkJftOHHjqnnA4y7gTaCTmfUAZgLLnHP7nHOfA8uA4kjyEK90cZeIxLOIp3cws3uAi4D9wOleci9ga8hq27y02tLDbfcKAmcN5OfnR5rNRktSrV5EEli9NX8zW25ma8I85gE4525xzuUBvwGurXxbmE25OtJrJjq3yDlX6JwrzMnJaVhpmlFaSmQtYmoSEpF4Vm/N3zk3rYHb+i3wZwJt+tuAvJBlucAOL31KtfS/NXD7LapDemq969TVmatmHxGJZ5GO9ikIeXkWsN57vgS4yBv1Mx7Y75zbCbwEzDCzLK+jd4aXFnfyu7Rtlu3oDEBE4lGkbf73m9kgoAL4BLjKS18KzAZKgCPApQDOuX1mdhewwlvvTufcvgjzEDMNCew6AxCReBRR8HfOnVNLugOuqWXZYmBxJPuNFw8vGBPrLIiINImu8G2ihUV59M1uV+96avYRkXik4C8i4kMK/lGitn4RiWcK/iIiPqTgLyLiQwr+IiI+pODfZBrGIyKtl4J/lLRJSwYgN6tNjHMiIlJTxLN6JqKvjAk70WgVUwbVPdlc3+x2/PTrY5lYkN1c2RIRaTYK/mE8eN6oOpe/cdMZ9Misv0Y/a0SP5sqSiEizUrNPGEn1TOafnpLcQjkREYkOBX8RER9S8BcR8SEFfxERH1LwbwKN8BeR1k7BX0TEhxT8RUR8SMG/CXSDFhFp7RT8RUR8SMFfRMSHNL1DM/vFRYWs+GRfrLMhIlInBf9mNm1oN6YN7RbrbIiI1EnNPk1gGukvIq1cswR/M/s3M3Nmlu29NjP7sZmVmNn7ZjY2ZN2Lzewj73Fxc+xfREQaJ+JmHzPLA6YD/wxJngUUeI9xwE+BcWbWGbgdKAQcsMrMljjnPo80HyIi0nDNUfP/IXADgWBeaR7wuAt4E+hkZj2AmcAy59w+L+AvA4qbIQ8tS60+ItLKRRT8zewsYLtzbnW1Rb2ArSGvt3lptaXHrf86f0zNRFczSUSkNak3+JvZcjNbE+YxD7gFuC3c28KkuTrSw+33CjNbaWYr9+7dW182o+bMkT1jtm8RkWipt83fOTctXLqZjQD6AqstMN9BLvCOmRURqNHnhayeC+zw0qdUS/9bLftdBCwCKCwsjK+6tpp9RKSVa3Kzj3PuA+dcV+dcH+dcHwKBfaxzbhewBLjIG/UzHtjvnNsJvATMMLMsM8sCZnhpcePm2YNjnQURkaiL1kVeS4HZQAlwBLgUwDm3z8zuAlZ4693pnIury2FH9OpU7zqa2E1EWrtm
C/5e7b/yuQOuqWW9xcDi5tqviIg0nq7wrUa1ehHxAwX/alx8dS2LiESFgr+IiA8p+DfAilumcfvcobHOhohIs1Hwr8aFueYsp0M6eVltg6/VLSAirZ2CfwOpI1hEEomCv4iIDyn4N4HpNEBEWjkFfxERH/Jl8H/wvFGNfo/G/4tIIvFl8E9JrqPZRkFeRHzAl8E/WrX4C8bnR2fDIiLNLFqzeia0cOcNW+6f0+L5EBFpKl/W/DVYR0T8zpfBvynNPiPzMps/IyIiMeLL4B/OhH5d6lzetUMGbVKTAZ05iEjrp+BfjQb7iIgf+C74D+7egUkF2TXSG1Kb//55I+mX3Y6MlOQo5ExEpOX4brTPKf2z6dI+vUnvPXNkT84c2bOZcyQi0vJ8V/MXEREF/6DvnjWMSQXZnNQ7K9ZZERGJOt81+9SmoFsH/ufycbHOhohIi1DNX0TEh3wb/O+YO5THLyuKdTZERGIiouBvZneY2XYze897zA5ZdpOZlZjZBjObGZJe7KWVmNmNkew/EpdM7MtpA3NitXsRkZhqjjb/HzrnfhCaYGZDgQXAMKAnsNzMBnqLHwGmA9uAFWa2xDn3YTPkQ0REGihaHb7zgCecc6XAZjMrASrbWEqcc5sAzOwJb10FfxGRFtQcbf7Xmtn7ZrbYzCrHSfYCtoass81Lqy29BjO7wsxWmtnKvXv3NkM2RUSkUr3B38yWm9maMI95wE+B/sBoYCfwYOXbwmzK1ZFeM9G5Rc65QudcYU6O2uZFRJpTvc0+zrlpDdmQmT0KPO+93AbkhSzOBXZ4z2tLFxGRFhLpaJ8eIS+/Aqzxni8BFphZupn1BQqAt4EVQIGZ9TWzNAKdwksiyYOIiDRepB2+D5jZaAJNN1uAKwGcc2vN7CkCHbllwDXOuXIAM7sWeAlIBhY759ZGmAcREWmkiIK/c+7COpbdA9wTJn0psDSS/YqISGR8e4WviIifKfgDI3rp/rwi4i++n9Xz+W+dSl7ntrHOhohIi0ro4H+8vKLedYar1i8iPpTQzT77vzwe6yyIiMSlhA7+IiISXkIH/3BzSYiISKIHf1P4FxEJJ7GDf6wzICISpxI7+Cv6i4iEldDBX0REwkvo4G9q+BERCSuhg79iv4rx4TgAAAaESURBVIhIeAkd/NXmLyISXmIH/1hnQEQkTiV28FfVX0QkrIQO/iIiEl5CB3/V+0VEwkvo4C8iIuEp+IuI+JBvgv+tc4bEOgsiInHDN8FfREROSOjbOIY6f1w+H+48wLenDoh1VkREYi7imr+ZfcvMNpjZWjN7ICT9JjMr8ZbNDEkv9tJKzOzGSPdfd95OPG+blsJDXx1Np7Zp0dyliEirEFHN38xOB+YBI51zpWbW1UsfCiwAhgE9geVmNtB72yPAdGAbsMLMljjnPowkH7XmT4M9RUTCirTZ52rgfudcKYBzbo+XPg94wkvfbGYlQJG3rMQ5twnAzJ7w1o1K8G+Tlszg7h24YHzvaGxeRKTVijT4DwQmmdk9wFHg35xzK4BewJsh623z0gC2VksfF27DZnYFcAVAfn5+kzP44vWnNfm9IiKJqt7gb2bLge5hFt3ivT8LGA+cDDxlZv0If3GtI3wfgwu3X+fcImARQGFhYdh1RESkaeoN/s65abUtM7OrgWeccw5428wqgGwCNfq8kFVzgR3e89rSRUSkhUQ62uePwBkAXoduGvApsARYYGbpZtYXKADeBlYABWbW18zSCHQKL4kwDyIi0kiRtvkvBhab2RrgGHCxdxaw1syeItCRWwZc45wrBzCza4GXgGRgsXNubYR5EBGRRrJArI5vhYWFbuXKlbHOhohIq2Jmq5xzheGWaXoHEREfUvAXEfEhBX8RER9qFW3+ZrYX+CSCTWQTGIXkJ34rs9/KCyqzX0RS5t7OuZxwC1pF8I+Uma2srdMjUfmtzH4rL6jMfhGtMqvZR0TEhxT8RUR8yC/Bf1GsMxADfiuz38oLKrNfRKXM
vmjzFxGRqvxS8xcRkRAK/iIiPpTQwb8l7xccbWa22Mz2eJPoVaZ1NrNlZvaR9zfLSzcz+7FX7vfNbGzIey721v/IzC6ORVkayszyzOwVM1vn3SP6Oi89YcttZhlm9raZrfbK/F0vva+ZveXl/0lvVly8mXOf9Mr8lpn1CdlW2PtoxyMzSzazd83see91opd3i5l9YGbvmdlKL61lj2vnXEI+CMwa+jHQj8BU06uBobHOVwTlOQ0YC6wJSXsAuNF7fiPwPe/5bOAFAjfVGQ+85aV3BjZ5f7O851mxLlsdZe4BjPWedwA2AkMTudxe3tt7z1OBt7yyPAUs8NJ/BlztPf8m8DPv+QLgSe/5UO+YTwf6et+F5FiXr45yfwf4LfC89zrRy7sFyK6W1qLHdSLX/Ivw7hfsnDsGVN4vuFVyzr0G7KuWPA94zHv+GHB2SPrjLuBNoJOZ9QBmAsucc/ucc58Dy4Di6Oe+aZxzO51z73jPDwLrCNwONGHL7eX9kPcy1Xs4AvfNeNpLr17mys/iaWCqmRkh99F2zm0GQu+jHVfMLBeYA/zCe20kcHnr0KLHdSIH/17UvF9wr1rWba26Oed2QiBQAl299NrK3mo/E+/0fgyBmnBCl9trAnkP2EPgC/0x8IVzrsxbJTT/wbJ5y/cDXWhdZf4RcANQ4b3uQmKXFwI/6C+b2SoL3K8cWvi4jvRmLvGstvsI+0FtZW+Vn4mZtQf+AFzvnDsQqOiFXzVMWqsrtwvc+Gi0mXUCngWGhFvN+9uqy2xmZwJ7nHOrzGxKZXKYVROivCEmOud2mFlXYJmZra9j3aiUOZFr/nXdRzhR7PZO//D+7vHSayt7q/tMzCyVQOD/jXPuGS854csN4Jz7AvgbgXbeTmZWWVkLzX+wbN7yTALNg62lzBOBs8xsC4Gm2TMInAkkankBcM7t8P7uIfADX0QLH9eJHPz9cL/gJUBlD//FwHMh6Rd5owTGA/u908iXgBlmluWNJJjhpcUlry33l8A659xDIYsSttxmluPV+DGzNsA0An0drwDneqtVL3PlZ3Eu8FcX6A2s7T7accU5d5NzLtc514fAd/Svzrmvk6DlBTCzdmbWofI5geNxDS19XMe61zuaDwK95BsJtJneEuv8RFiW3wE7geMEfvEvJ9DW+RfgI+9vZ29dAx7xyv0BUBiyncsIdIaVAJfGulz1lPlUAqex7wPveY/ZiVxuYCTwrlfmNcBtXno/AsGsBPg9kO6lZ3ivS7zl/UK2dYv3WWwAZsW6bA0o+xROjPZJ2PJ6ZVvtPdZWxqaWPq41vYOIiA8lcrOPiIjUQsFfRMSHFPxFRHxIwV9ExIcU/EVEfEjBX0TEhxT8RUR86P8AOBU7I6oM3DEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = ExpectedSARSAAgent(env)\n",
    "\n",
    "# Training: play `episodes` episodes with learning enabled and record\n",
    "# the total reward of each one.\n",
    "episodes = 5000\n",
    "episode_rewards = []\n",
    "for episode in range(episodes):\n",
    "    episode_reward = play_qlearning(env, agent, train=True)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    \n",
    "# Learning curve: total reward of each training episode.\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "# Evaluation\n",
    "agent.epsilon = 0. # disable exploration\n",
    "\n",
    "# 100 episodes with train left at its default (False), so nothing is learned\n",
    "# here. The printed label reads 'average episode reward'.\n",
    "episode_rewards = [play_qlearning(env, agent) for _ in range(100)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Q 学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QLearningAgent:\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        \n",
    "    def decide(self, state):\n",
    "        if np.random.uniform() > self.epsilon:\n",
    "            action = self.q[state].argmax()\n",
    "        else:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        return action\n",
    "    \n",
    "    def learn(self, state, action, reward, next_state, done):\n",
    "        u = reward + self.gamma * self.q[next_state].max() * (1. - done)\n",
    "        td_error = u - self.q[state, action]\n",
    "        self.q[state, action] += self.learning_rate * td_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 841 / 100 = 8.41\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deZwU1bn/8c8zK8M6AzPAwAzMIPsmywAiioDIaoLrT9S4xVyiEY1bFGMSvSoJN95cs7lckx/GXGOImhi5StyiJhoXFkUEFRlZZBWQTdlhzv2jq3u6Z3rWnl6m+/t+vfo1VadOVz1d0/306VNVp8w5h4iIpJa0eAcgIiKxp+QvIpKClPxFRFKQkr+ISApS8hcRSUEZ8Q6gPvLz811JSUm8wxARaVaWLVu20zlXEG5Zs0j+JSUlLF26NN5hiIg0K2a2oaZl6vYREUlBSv4iIilIyV9EJAUp+YuIpCAlfxGRFKTkLyKSgpT8RURSkJJ/lH3y+ZcsXrcLgJ1fHeb5lVur1VmxaQ8fbNrL3gNHqahwvPrxdj7csg+APQeO4B92e9+hoxw7XhF43vqd+3ljzc4GxbN7/xGeWLKRw8eO11jneIVj74GjbPhiPxUVjvLtX/L8ym2BOA4fO87D//yU4OHADx09zv7Dxzh45DgHj9S87qr2HjjK9n2H2Lb3EIeOHmf3/iMcPV7Bl4eOcuDIscB6P993KDDf0NcLsPdg6L7zL3PO4Zxjz4EjgfJd+48Enue3/ctD7Np/hC8PHa33to8dr+CJJRs5XhE6bPrmPQd5dfX2kLLX1+xg3c79fLR1H8s27OLNT3dSvv3LGtdddV8cr3D8aclnHDtewRdfHa5W3/8/Bfjj4s/Y8MV+HnitnP2Hj4XU27jrAM8s3wyEvvfqEry/Pt62j/tfLeeTz79k295DbNlzMLAPDhw5xta9BwN1V27ey5vlO3lq2aZq+xzgzfKdfLrjKwCcc2HrNKWKCsejb66v9f+8/3D1fb9x1wGOeu+vp9/bxMZdB1i+0fe5Dufo8Qr2eds4dLTyM3PkWEWD3mORsOYwnn9ZWZlLlIu89h06igFtWmSGlG/cdYDF63axde9BLhjRjedXbuWEjq256DfvAPCLmUP47evr+GDzXv539imUFrTi16+Uc+RYBfP/tS6wnu+MO4EHXvsUgFN65vNG+U5mje3BgSPHeOztz8jOSGP1PVP5+0efc+Wjvn1yy5Q+XFBWTKvsDM576E3W7tjPucOKuHZCT9q3yuLA0eN869GlTBvYmTv/98PAtn4xcwhvr/2CQ0cruH5iL06797U6X3+7nEz2HvS9OVtlpbPqrim8uno7VzyyJKTe/1w5kr+8u5nnV25jwayTyG+TTac22Xz+5WFaZqZTNvflaknRr6BNNju+rJ7A/C4cWczC5Vu45+yBjCrtwMnzXgHg3GFF/PndTV6dbmzafYDX1+zkqatGc95DbwHwywuHct0f3wusa2K/jizfuIedXx1h+qBCAJ77wPcFfUb/TrTOzuDp9zaHbP+xK0dx+18/YMMXB3jphrGYwdznPuJYheO84UV8d8FyAIrycti0+yBjexewavNepg7qzKWjS5h03z8BmDqwM2u2+xJbufe3qr6d2zBrbA8efWsD72/cw/g+Bdxz9iDGeK/58W+N4qLfvhOo37tTaz75/Ct+dGZ/KpzjqWWb+Hhb5ZfID6b3457nPgrZxvcm9+HRN9czorQ9z62o3jh59JsjuWz+Yq4edwK9OrbmxifeZ2RJey4a1Y33N+3hkX+tB6BHfivmTO3LrP9ZFva13Pm1/iHvv665OWzeczCkzqyxPTh09Di/fyv02qRlP5jIL/6+JlA+pmcH7j3vRA4cOcYrH29n9bavuHVqH0bO/Tu9O7Xm/OHFzF1U+TrvPW8w33tqBddO6MmIkvbkZKUDkJOZzpm/eiNQLyPNOOa9Ly8a
1Y031uxkXJ8C/vDOZ9xz1kBaZ2dwbdD7B+DE4lze37iHllnp3HveiVzz+Lshyx//1ij6FbYlI9247o/v8erqHYFl/3HuIG798wcAPHfdKUz/5RtU9eINY+ndqU3YfVoXM1vmnCsLu0zJP7yKCsdv31jLhSO7hST6kjnPAbB+3vSQ+kPuepE9B2LzjS0iqaVqvqmv2pJ/3Lp9zGyKma02s3IzmxOvOGryysfb+fGij7n72Q9rrffPT3bw/MqtSvwi0qzEZWwfM0sH7gfOADYBS8xsoXOu9kwbQ8s37gHgqyp9olVdOn9xLMIREWlS8Wr5jwTKnXNrnXNHgAXAjDjFAvgOupTMeY6rH1vGV4eP8etXy2utP+Fnr3Heg2/GKDqJloe+MTzeIQSM71PA764Y0ejn9+3csH7he88b3OhtNcSpvfJjsh2AwUXtAtPj+4QdzJKsjDTS0yyk7PKTS6rV69u5DReOLKZ87lQm9O0YsuxfcyYwY0iXkLIz+neqMa6sjDSmDy7kkctHMK6GuJb/6Iyw5f+aM6HG9UYiXsm/K7AxaH6TVxZgZrPMbKmZLd2xYwfRts87iPm3ldu4tsoBm3DW7tjP0g27ox1Wyqr64WyMx64cVWedKQM7h8xfOrp7jXVnje0BwIAubcltmVljvXB+dv6JddY5qUcHxvXpyLUTegbK/Ilz2qDO1ep/fPeUwPS0QZ352oldqtUJNmVA5ToGF7Xj/LLianVmjqgse/XmcSHLHrx4GPmtswLz1wXFGayVdzAVfH3Vv7pwaGC+KC+HE4tzA/NtW2Tw4MXD+La3b2+b2pc1c6cGlj9y+QiW/mAio0rbh2yjpbeNp79zckh/+MLZp3Bqr3zSDB6s4Yv9k3um8umPpwXmH/rGcO74Wv9q9Z6/fiw/OWcwGelpzL98RMh2uubm8IuZQ3nuulMA6FfYlt9cGtq1Hlz/8W+N4v6LhjG+b8eQL5ohQfsit2UWk6p8gaz7yTS65uaEfR2RilfyD/fJDjny7Jx72DlX5pwrKygI/00ZLR9trfkUO4mNZ64ZEzJ/14wBNdb96bm+Fux9F4Qm2FOCWpzfPq1Htee1aeHr9bx5Um8AFsw6ibtmDOS603sBVGuh+b+Qpg0qJDsjLRDniJK8sMkj2GlhWnvByeGRK0bwrVN9Md40qQ9vzpnAoutO5eJRvi+jH55Zuf42LTK4eVJvWmRWJtn7LxrGab2rb+P5608FoEOrLB66ZDhPXjWa9++YxMLZpwTqdGnXgt9dMYI1c6cy79zBLJw9husm9KQ0vxULZ4/hP88/kR+d2Z+pgwrJa1mZ/Md7reHZ43tyyUm+OC8a1Y3ld0zi7KFdyUz37a/coOe8fst4/nL1yTxyxQjW/WQaK+6czNRBhXxnXE9mjijmktHdyUxP45lrxrDk9omM79uR/NbZ/Onbo1kblLCfvGo0V487ISR5+v3+myP59MfTaJGZXu1AafAXUZrBOUO7MmVgZ8yMf3xvHC/fOLba+oK9dvM4Xr7xtMB8ZrrvfZDlvR9+W+ULYKT3peWvBzCmZ+X7suqvteBlw7rlYhZ5I6gm8RrPfxMQ3OwoArbEKRaf6O1jaYTi9i3pV9iWj7buo1PbbC4dXcLEfp04+4F/8fm+ytNA/R/u/+e1WPsVtmXLnoN89sUBwPfT/9XVO+jWvmW1bfiT+dXjejKytEPgg3rjGb258QzfF4L/7K7HvzUqcG5+RprhP0u1c7sWPHnVyQD8u3caY/CppX6ts30fte4dWrLBiw18ySQrI40uVVp3XXJz6JKbQ/8ubQOv8aJR3Xj8nc+4ZUrfQLJ99eZx5GSmY2YM7NqO9fOms/Orw5Td8zIPXDyMvp3b8uINY+nQypeAR5SEtqDDnUUyuCiXwUW51aYBfntZWeCU4KHd8njyqtEMLc4lIz2Nu88aGKh33wVDuO+CIYH5P199Mh9s2oOZkW4wvk9oN0q7lpnMO7eyG+rEMEk9LejX4IAu7RjQ
pbKL5/VbxgemqybMF28Yy7MrtnJa7wKGd88LlK/9Sehr796hFQBzzx7ICQWtq20foCS/Vch8r46tuWb8Ccwc0Q2Aif078eRVo/loq+86nfsuGMJjb29gUNfKWDPT03jw4mFc/Yd3uXVKXxYsqewEuWBEMeu/2M+NZ/Sudjp5U4tXy38J0MvMSs0sC5gJLIxTLNW4oB8hn+87zAurtrFMXTxR89drxjB7fGUXgpnveoIfTu8HEEiMXXJzuONrlb8A/F0Fwfp2bsuEvp24fEwpAJ3atgAgLSgh+PvV/SXpaRZI/DU5uWc+Zw8tAmDygM74z5AO1zAbVNSOZ6+tbFm/evO4QCv0ylNKQ+qW5Leqlvhr4j8tOz1oo6X5rejcrkVIvfzW2ayfN51p3nULvTu1oUPr7Hptoy7dO7Ri/uVl/P/LfC3cESXtyUivO40M754X+J9EQ3H7lhSH+YIH3+u/8YzeIYm/NheP6s5JPTrUq66Z8b3JfUO2PaKkPZeOLgF83UO3Tukb8sUFMHVQIevnTSevVRZPXjWaX8z0fVG2yEznjq8NiHrihzi1/J1zx8xsNvACkA7Md86tikcsfhbU9A9uWS7bsJtv13DRikTupjN6M6Q4lyHFuYGD7P7/RFGe7wMV3EqcNqiQl288jZ4dw7fMqhrWPY8FSzbSu5OvvhkM9FphafX4SX3WkC609rqHglvh/kRsYX4ypnut8BdvGMuqLXspDWotXjiyGz96pnFv9QrvAuUmOBwSkQl9az6wGW03ndGbN8obdlV7oqv6ayxW4nYbR+fcImBRvLYvia9bh5a8OWcCnduGtmzrm/gBzh9exJie+XTNzeHBi4fRt7AtFf7EXY/k//OZQ8OW+38bhkvE/i+V3p3aVLsyMzM9jR+e2Z+l63fV+zX4Vcbd4KcmjWtP78W13jEZiYzG9pGEE5zcu+TmVPvJ3BBmFjhbYuqgQkrzW9XaZVNftX2B1LXeK08prfFMlNqM834BDQzqPxZprJRP/hUVjhdWbYt3GAnn4lHdorp+/9k34RLl4/92UlS33dbrT/UfNG2MRy4fwbnDisjNqeybrexaik7TfPrgQj68a3LIgU6Rxopbt0+i+OOSz7j96ZXcMqVPvENJKNHIX7+7YgSXewPATerfmf/+x1rGBp2euPbH0zCLXvL0y8lKD2yrsYZ2y2Not9ADiAtmjQ6MQBktLbNS/iMrTSQlW/4fbd3HO2u/AGDb3kMAfO79FZ8+ndvWuvwm71TI+uqamxPotjixOJfh3fNYP296yGmEaWkW9cQfzW21b5UVt4N3Ig2Vks2Iqb94HQg9xzlWSae5+MaobvzwrytrXN7Q3eW/RL2xoxOKSNNKyZa/1K2uL0N9WYo0bymd/EvmPBe4Eu93b66PbzAJ4ooxJbz7w/ADTAVT7hdp3lI6+QO8/NH2uiulkDYtMmnfKqvOevW5QEpEElfKJ38JFS6lB4+L0zkwXELN6/jgzkmB6b/fdBp/vnp0U4UnIk1EyT9FjQ0zAiSEb9H7xzqaPb4nE/p1rLGeX/C4JCcUtGZ4d50BI5JolPxTVE0XcYVr0fuviM1MTyPLG8QreIjaP80Kf1FWuJE0RSQxKPmnqMkDOoftjqltKAUzuHFSb2aN7cE5w3z33kkzGBVmBMSXbhjLwtljqpWLSGJQ8k9h9e2OCe7hadsik+9P6xe4eYX/lM+7q9xspVenNiE38RCRxJKSF3lJzYLvgep3Ss98/rh4Y8jBYP9Qxv6yS0aXMGNoV44cq4h+kCISMSV/CRhZ2p5Te1U/EOw/9TPcMd6qvwpEpHlQt48EVB03vzYu9JbLItLMqOUvAVVb9n/77qlkpBlPv7e5Wt2MNF+74cpTqt9KUUQSn5J/Clk/b3rghuT10a/QN7Knv40fPJ5PepppkDaRZkzdPlKnXt6dtXoE3YtWRJo3tfylTmcP7Uqvjm0YFOZMIBFpntTyl4CaLu8yMyV+kSSj5J8i5kztC/hO
5xQRUbdPkuiam8PmPQfDLvvJOYO4cKRvLJ/HrhzF4WPHYxmaiCSgiFr+Zna+ma0yswozK6uy7DYzKzez1WY2Oah8ildWbmZzItm+VGqXU78LrLIy0kJG3RSR1BRpt89K4Bzgn8GFZtYfmAkMAKYAD5hZupmlA/cDU4H+wIVeXRERiaGIun2ccx9B2Pu5zgAWOOcOA+vMrBwY6S0rd86t9Z63wKv7YSRxiIhIw0Srz78r8HbQ/CavDGBjlfJR4VZgZrOAWQDduoUfe76hNu46wLuf7W6SdYmINGd1Jn8zexnoHGbR7c65Z2p6WpgyR/huprCDxDjnHgYeBigrK2uSgWTOfuBNdn51uClWlXBO7Z3Ph97N6BurvscNRKT5qzP5O+cmNmK9m4DioPkiYIs3XVN51CVT4v/e5D7c+8LqwPwtk/vy3/9YC8Di209n5Ny/12s9r9x0Giu37GPHl4e5cGRx3U8QkaQQrW6fhcDjZvZfQBegF7AY3y+CXmZWCmzGd1D4oijF0Ox1btuCbfsOhV1W0Do7ZD496A5cHdvUf3TOHgWt6VHQunEBikizFVHyN7OzgV8BBcBzZrbcOTfZObfKzJ7AdyD3GHCNc+6495zZwAtAOjDfObcqoleQxGq5R3rNl+OKiNRDRKd6Oueeds4VOeeynXOdnHOTg5bNdc6d4Jzr45z7W1D5Iudcb2/Z3Ei2nyp+ML1fg5/zj++Na/pARCRppMQVvnc8szIwPHFzlBHupup1HALv3qFyBM70Wn9CiEgqSonk/+hbG+IdQkQae6rTijsn8etXyjlraNe6K4tISkmJ5N9c1dZer89tFNu2yOT70xreZSQiyU+jeoqIpCAl/wQWZtgMEZEmoW6fBJTbMpM9B44G5l2YHp5wZYuuO5UjxyuiGJmIJAsl/wT0t++eyqbdB7l+wXIAOrTOqlYnXI9//y7N94wmEYktJf8EVNguh8J2OYH5Yd3yaq1/QZmGZRCRhlGffzN3QVkx/3He4HiHISLNjJJ/Ari3Ecnb3+evY8Ii0hhK/gng/LJi8lr6hlP+2fkn1us5/vP8lfxFpDGU/BPM+L4dq5XVnuCV/UWk4ZT8E1hJfkvAd9N1EZGmpLN9EkS4UzcfuGg4SzfsCjs+f7jz/EVE6ktNygQT3InTrmUmp/frVHt99fqISCMo+SeI757eC4BW2fX7MaaGv4hEQt0+CeKKMaVcMaa0znr/O/sUVmzeQ4X/VM8oxyUiyUkt/2ZmUFE7Lh7VPd5hiEgzp+TfXOmIr4hEQMm/mfKnfh3wFZHGUPJv5ky9/iLSCDrgG2ef/nhao5533vAi3t2wm+sn9mriiEQkFSj5x1l6WuNa7i2zMvj5zKFNHI2IpIqIun3M7F4z+9jMVpjZ02aWG7TsNjMrN7PVZjY5qHyKV1ZuZnMi2X4qeef7p/P6LePjHYaIJIlI+/xfAgY65wYDnwC3AZhZf2AmMACYAjxgZulmlg7cD0wF+gMXenWlDp3atqC4fct4hyEiSSKi5O+ce9E5d8ybfRso8qZnAAucc4edc+uAcmCk9yh3zq11zh0BFnh1RUQkhprybJ9vAn/zprsCG4OWbfLKaiqvxsxmmdlSM1u6Y8eOJgxTRETqPOBrZi8DncMsut0594xX53bgGPAH/9PC1HeE/7IJe7WSc+5h4GGAsrIyXdEkItKE6kz+zrmJtS03s8uAM4HTnQtcdroJCL6reBGwxZuuqTwqXly1LZqrFxFpliI922cKcCvwdefcgaBFC4GZZpZtZqVAL2AxsAToZWalZpaF76DwwkhiqMtzH2yN5upFRJqlSM/z/zWQDbxkvnEG3nbOXeWcW2VmTwAf4usOusY5dxzAzGYDLwDpwHzn3KoIYxARkQaKKPk753rWsmwuMDdM+SJgUSTbFRGRyGhsHxGRFKTkLyKSgpT842jaoHBn0IqIRJ8GdouTJ68azbBuefEOQ0RSlFr+cZKdkdboET1FRCKl5C8ikoKS
PvnrVrciItUlffIXEZHqkj756wbnIiLVJX3yb2ojSnSGjog0f0mf/BO1z9/CjnotIhIbSZ/8E1XXvJx4hyAiKUzJPw7+7dRS2rfKincYIpLCkj7564CviEh1SZ/8RUSkuqRP/ol6wFdEJJ6SPvlv3Xsw3iGIiCScpE/+7362J94hiIgknKRP/hUJ2O9jOgotInGW9Mk/AXM/LhGDEpGUkvTJP1HktcxkygDduUtEEoOSf4xkpqcxtFsuoG4fEYk/Jf8YUb4XkUQSUfI3s7vNbIWZLTezF82si1duZvZLMyv3lg8Les5lZrbGe1wW6QsQEZGGi7Tlf69zbrBzbgjwLPAjr3wq0Mt7zAIeBDCz9sAdwChgJHCHmWmMZBGRGIso+Tvn9gXNtgL8p7HMAH7vfN4Gcs2sEJgMvOSc2+Wc2w28BEyJJIZENvfsgfEOQUQkrIj7/M1srpltBC6msuXfFdgYVG2TV1ZTebj1zjKzpWa2dMeOHZGGGRcXj+oemNb4/SKSSOpM/mb2spmtDPOYAeCcu905Vwz8AZjtf1qYVblayqsXOvewc67MOVdWUFBQv1eT4HR2v4gkijqTv3NuonNuYJjHM1WqPg6c601vAoqDlhUBW2opbzauOu2EBtVvl5MJwKyxPQJl+g0gIvEW6dk+vYJmvw587E0vBC71zvo5CdjrnNsKvABMMrM870DvJK+s2Ti9X6cG1c9t6Uv+E/p2jEY4IiKNkhHh8+eZWR+gAtgAXOWVLwKmAeXAAeAKAOfcLjO7G1ji1bvLObcrwhiaBVfDtIhIPESU/J1z59ZQ7oBralg2H5gfyXbjrU+nNqz+/Mt4hyEi0mi6wrcRfnXR0AY/x2qYFhGJByX/RtCgnCLS3EXa5y9h/Pclw/nq0DEg9ItCXxoikiiU/KNgsoZuFpEEp26fGNLIniKSKJT8Y0SJX0QSiZJ/DKnPX0QShZJ/I7gGXKbVo6AVANkZ6ZWF+hUgInGmA75R9ssLh/Luht10bteislC/AEQkztTyr0ObFpF9P7Ztkcm4PhrXR0QSi5J/I0Q8Nr+6fUQkztTt0wg19fkv/cFEjhyriHE0IiINp+Rflwb0z+e3zo5eHCIiTUjdPjHUkLOERESiScm/DuHSdcc2LcKUiog0H0r+jdC+VVajnqebuItIolDyjyF1+4hIolDyj9D6edMb/Bz9AhCReFPyr4MLGpAnPU1JW0SSQ0on/3nnDKp33R9M78cL149tku2q+0dE4i2lk//Mkd3qXff8smJ6dmwdxWhERGInpZN/QzTlePzq8xeReGuS5G9mN5uZM7N8b97M7JdmVm5mK8xsWFDdy8xsjfe4rCm2H03qoBGRZBTx8A5mVgycAXwWVDwV6OU9RgEPAqPMrD1wB1CGL68uM7OFzrndkcYRbU3RVtfNXEQkUTRFy/8+4BZCG8kzgN87n7eBXDMrBCYDLznndnkJ/yVgShPEEDXRSNi6paOIxFtEyd/Mvg5sds69X2VRV2Bj0Pwmr6ym8oRnTZix9QtAROKtzm4fM3sZ6Bxm0e3A94FJ4Z4WpszVUh5uu7OAWQDdutX/rJymNKFvR9769IsmW59a/CKSKOpM/s65ieHKzWwQUAq877WKi4B3zWwkvhZ9cVD1ImCLVz6uSvlrNWz3YeBhgLKysqRoK6vFLyKJotHdPs65D5xzHZ1zJc65EnyJfZhzbhuwELjUO+vnJGCvc24r8AIwyczyzCwP36+GFyJ/GQ330DeG11nHapiOlH4BiEi8RetmLouAaUA5cAC4AsA5t8vM7gaWePXucs7tilIMtZoyMFxPVnW6GldEklGTJX+v9e+fdsA1NdSbD8xvqu3GSlO01rvk+u4D0DU3J/KViYhEQLdxrEVtbf63bpvA7v1HG7S+s4Z0Ja9lFqf1LogsMBGRCCn511PVIRkK2+VQ2K5hLXgzY1yfjk0ZlohIo2hsn1oYOkNHRJKTkn896QwdEUkmSv51UMNfRJJRSvb5
nzOsckSJk3q05+21kZ1t+tRVo2mVnZK7UkSaqZRs+f/neScGph+42HexV1Z643dFWUl7+hW2jTguEZFYScnknxZ0L950rzM/OzMld4WIpChlvLqo019EkpCSv6euk3l0to+IJBMlf0+4Br6ZxvYRkeSk5K8WvYikICV/vzANfF3dKyLJSsm/nqqO7SMi0pwp+fsZzBjSJbTI1PoXkeSk5O/n4JxhRTUu1tk+IpJMUj75Byf1cPm9XU5mzGIREYkVDUjjqal3589Xn8zra3aQGcHwDyIiiUbJv1ZGSX4rSvJbxTsQEZEmpeasxwjXr6+jvSKSnJT8PUrzIpJKUj756yQeEUlFKZ/8g+W3zgage4eWXom+GkQkOUWU/M3sTjPbbGbLvce0oGW3mVm5ma02s8lB5VO8snIzmxPJ9ptav8K2/OU7J3PrlL7xDkVEJKqaouV/n3NuiPdYBGBm/YGZwABgCvCAmaWbWTpwPzAV6A9c6NVNGMO65ZGRpha/iCS3aJ3qOQNY4Jw7DKwzs3JgpLes3Dm3FsDMFnh1P4xSHPXmgsZx8J/aOfqEDvEKR0Qkqpqi5T/bzFaY2Xwzy/PKugIbg+ps8spqKq/GzGaZ2VIzW7pjx44mCDM8CzNuQ+9ObXjrtgl8c0xJ1LYrIhJPdSZ/M3vZzFaGecwAHgROAIYAW4Gf+Z8WZlWulvLqhc497Jwrc86VFRQU1OvFNKXCdjlhvxhERJJBnd0+zrmJ9VmRmf0GeNab3QQUBy0uArZ40zWVx0V2hu/777KTS+IZhohITEXU529mhc65rd7s2cBKb3oh8LiZ/RfQBegFLMbX8u9lZqXAZnwHhS+KJIZIZaanse4n0+quKCKSRCI94PtTMxuCr+tmPfBtAOfcKjN7At+B3GPANc654wBmNht4AUgH5jvnVkUYQ8TUvSMiqSai5O+cu6SWZXOBuWHKFwGLItmuiIhERlf4ioikICV/EZEUlPTJv7h9TrxDEBFJOEmf/Du1aRHvEEREEk5SJ3/nHEs37I53GCIiCSepk//xCt2iRUQknKRO/iIiEl5SJ3+1+0VEwkvq5F/hlP5FRMJJ6uSv3C8iEp6Sv4hICkru5F+l179rbg6v3HRanKIREUkcyZ38q4BFDDoAAAeiSURBVLT8+3dpS4+C1vEJRkQkgSR18q96wFcDN4uI+CR18q/a5a9h+0VEfJI7+VfJ/qa2v4gIkPTJv0q3j3K/iAiQ9Mk/dF7JX0TEJ6mTf/UDvsr+IiKQ5Mm/2jVeyv0iIkCSJ//2LbNC5pX7RUR8kjr5p6WFpntTp7+ICJDkyb8qpX4REZ+Ik7+ZXWtmq81slZn9NKj8NjMr95ZNDiqf4pWVm9mcSLffsFhjuTURkcSVEcmTzWw8MAMY7Jw7bGYdvfL+wExgANAFeNnMentPux84A9gELDGzhc65DyOJo97xxmIjIiLNQETJH7gamOecOwzgnNvulc8AFnjl68ysHBjpLSt3zq0FMLMFXt2oJ//xfQq4eXKfaG9GRKRZiLTbpzdwqpm9Y2b/MLMRXnlXYGNQvU1eWU3l1ZjZLDNbamZLd+zYEWGY8MgVIynKaxnxekREkkGdLX8zexnoHGbR7d7z84CTgBHAE2bWg/A9LI7wXzZhb7ninHsYeBigrKxMt2UREWlCdSZ/59zEmpaZ2dXAX5xvEJ3FZlYB5ONr0RcHVS0CtnjTNZVHxd1nDeTEonbR3ISISLMTabfPX4EJAN4B3SxgJ7AQmGlm2WZWCvQCFgNLgF5mVmpmWfgOCi+MMIZaXXJSdwYX5UZzEyIizU6kB3znA/PNbCVwBLjM+xWwysyewHcg9xhwjXPuOICZzQZeANKB+c65VRHGICIiDWRVhz1ORGVlZW7p0qXxDkNEpFkxs2XOubJwy1LqCl8REfFR8hcRSUFK/iIiKUjJX0QkBSn5i4ikICV/EZEU1CxO9TSzHcCGCFaRj+/is0SjuBpGcTWM4mqYZIyru3Ou
INyCZpH8I2VmS2s61zWeFFfDKK6GUVwNk2pxqdtHRCQFKfmLiKSgVEn+D8c7gBooroZRXA2juBompeJKiT5/EREJlSotfxERCaLkLyKSgpI6+ZvZFDNbbWblZjYnDttfb2YfmNlyM1vqlbU3s5fMbI33N88rNzP7pRfrCjMb1oRxzDez7d59F/xlDY7DzC7z6q8xs8uiFNedZrbZ22fLzWxa0LLbvLhWm9nkoPIm/T+bWbGZvWpmH5nZKjP7rlce131WS1xx3Wdm1sLMFpvZ+15c/+6Vl5rv/t5rzOxP3g2c8G7y9Cdv2++YWUld8TZxXL8zs3VB+2uIVx6z9763znQze8/MnvXmY7u/nHNJ+cB3s5hPgR747jD2PtA/xjGsB/KrlP0UmONNzwH+w5ueBvwN3/2PTwLeacI4xgLDgJWNjQNoD6z1/uZ503lRiOtO4OYwdft7/8NsoNT736ZH4/8MFALDvOk2wCfe9uO6z2qJK677zHvdrb3pTOAdbz88Acz0yh8CrvamvwM85E3PBP5UW7xRiOt3wHlh6sfsve+t90bgceBZbz6m+yuZW/4jgXLn3Frn3BFgATAjzjGBL4ZHvelHgbOCyn/vfN4Gcs2ssCk26Jz7J7ArwjgmAy8553Y553YDLwFTohBXTWYAC5xzh51z64ByfP/jJv8/O+e2Oufe9aa/BD4CuhLnfVZLXDWJyT7zXvdX3mym93D4bvH6lFdedX/59+NTwOlmZrXE29Rx1SRm730zKwKmA7/15o0Y769kTv5dgY1B85uo/YMSDQ540cyWmdksr6yTc24r+D7MQEevPNbxNjSOWMY32/vZPd/ftRKvuLyf2EPxtRoTZp9ViQvivM+8LozlwHZ8yfFTYI9z7liYbQS27y3fC3SIRVzOOf/+muvtr/vMLLtqXFW2H43/48+BW4AKb74DMd5fyZz8LUxZrM9rHeOcGwZMBa4xs7G11E2EeKHmOGIV34PACcAQYCvws3jFZWatgT8D1zvn9tVWNZaxhYkr7vvMOXfcOTcEKMLX+uxXyzbiFpeZDQRuA/oCI/B15dway7jM7Exgu3NuWXBxLduISlzJnPw3AcVB80XAllgG4Jzb4v3dDjyN70Pxub87x/u73ase63gbGkdM4nPOfe59YCuA31D5MzamcZlZJr4E+wfn3F+84rjvs3BxJco+82LZA7yGr88818wywmwjsH1veTt83X+xiGuK133mnHOHgUeI/f4aA3zdzNbj63KbgO+XQGz3V6QHLRL1AWTgOzBTSuVBrQEx3H4roE3Q9Jv4+gnvJfSg4U+96emEHmxa3MTxlBB6YLVBceBrIa3Dd8Arz5tuH4W4CoOmb8DXpwkwgNCDW2vxHbhs8v+z99p/D/y8Snlc91ktccV1nwEFQK43nQO8DpwJPEnoAczveNPXEHoA84na4o1CXIVB+/PnwLx4vPe9dY+j8oBvTPdXkyWXRHzgO3r/Cb7+x9tjvO0e3j/mfWCVf/v4+ur+Dqzx/rYPeiPe78X6AVDWhLH8EV93wFF8rYUrGxMH8E18B5XKgSuiFNf/eNtdASwkNLHd7sW1Gpgarf8zcAq+n88rgOXeY1q891ktccV1nwGDgfe87a8EfhT0GVjsvfYngWyvvIU3X+4t71FXvE0c1yve/loJPEblGUExe+8HrXcclck/pvtLwzuIiKSgZO7zFxGRGij5i4ikICV/EZEUpOQvIpKClPxFRFKQkr+ISApS8hcRSUH/BzPOnoahsuxMAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = QLearningAgent(env)\n",
    "\n",
    "# Training: play `episodes` episodes with learning enabled and record\n",
    "# the total reward of each one.\n",
    "episodes = 4000\n",
    "episode_rewards = []\n",
    "for episode in range(episodes):\n",
    "    episode_reward = play_qlearning(env, agent, train=True)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    \n",
    "# Learning curve: total reward of each training episode.\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "# Evaluation\n",
    "agent.epsilon = 0. # disable exploration\n",
    "\n",
    "# 100 episodes with train left at its default (False), so nothing is learned\n",
    "# here. The printed label reads 'average episode reward'.\n",
    "episode_rewards = [play_qlearning(env, agent) for _ in range(100)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 双重 Q 学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DoubleQLearningAgent:\n",
    "    def __init__(self, env, gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        self.gamma = gamma\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epsilon = epsilon\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q0 = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        self.q1 = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        \n",
    "    def decide(self, state):\n",
    "        if np.random.uniform() > self.epsilon:\n",
    "            action = (self.q0 + self.q1)[state].argmax()\n",
    "        else:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        return action\n",
    "    \n",
    "    def learn(self, state, action, reward, next_state, done):\n",
    "        if np.random.randint(2):\n",
    "            self.q0, self.q1 = self.q1, self.q0\n",
    "        a = self.q0[next_state].argmax()\n",
    "        u = reward + self.gamma * self.q1[next_state, a] * (1. - done)\n",
    "        td_error = u - self.q0[state, action]\n",
    "        self.q0[state, action] += self.learning_rate * td_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 809 / 100 = 8.09\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de3wV9Z3/8dcnCQkBuV+FcAlyE7wBKSqoVVGuVtTqFlvrtT9tK921ttuF6tbWqmXb7nZr17ZLW7qr2y5aq5VttYjVum1XhaCggAIRIoRrKBAQcs/n98eZhJPk5HqSnOTM+/l45JGZ73xn5nPmzPmcOd/5zoy5OyIiEi4piQ5AREQ6npK/iEgIKfmLiISQkr+ISAgp+YuIhFBaogNojoEDB/ro0aMTHYaISJeyfv36Q+4+KNa0LpH8R48eTW5ubqLDEBHpUszsg4amqdlHRCSElPxFREJIyV9EJISU/EVEQkjJX0QkhJT8RURCSMlfRCSElPxb4M/bD7FpTxG/yt3Nr3J3U3077Nz8w7y3/xj/s3EvRcXlAFRWOU+t201FZRUVlVVsO3Cc4rJK9h4tpriskp/+aQf/8ZedFJ2M1C+vrKLgyEl++cYuDh4vAaCouJzC46U169956AR/yTtEeWUV6z84wpotByguq6SkvLJerM9t2MOHpRUUHi/lWEk5B46VsK+omJtXrCXv4HEA3J0nXv+A/UUllJRXsq+omAPHSqisOnWb7+r5yyoi8RUeL2VfUeQ1VFZ5TfwAR06UceREWc34ht1H+UveoZrxoyfLiL6F+J6jxfx5+yHyDn7I6s37eXHzfrbuj8RWdLKc3YdP8tyGPZSUV3K85NR66vqwtIKyiio27D7Km7uO8Nu39/LAc5s4cKwkZv1Ne4p4a9eRmliPnCijqsopKi6norKqpt7BYyWcKK3gWEk5R06UcayknJLySk6UVlBaUcmRE2W13ufoeA7WWXd1jNXzV29zgKoq5+jJSCzv7jvG+g8O12yfg8dKauKL5WRZBSXllTH3g427j/J2wVGqqpz7f/MOr7x3kMMnyvjd2/uAyD5aVHzqNUU7UVpRs49Xv9YPSyt49q0C9heVUF5ZhbvzndXv8cybBTz2Sh6vbisE4LX3/8qLm/fz/Dv7eHffMU6WRZbt7hQcOcn2A8dr7QdFJ8spLovEcPRkGXkHP+T5d/adml5cTnllFfmHTtSUHTlRxtGTke1fVeX85H93UFZRxd6jxfzszzs5VlLO7zft40Cw/apfU/W2enp9Qa19tbQisv6ik+X8cetBnl5fwPYDx2ttyz9tL6x5z4uKIzG/t/8Yj7+WT/6hEzy5bhelFbXfg5LySk6WVfB/7x9i/QdHYr53AO/tP8YXn9wQ8z1uL9YV7uefk5PjneEir9FLfldr/NYZo3l1WyE7o3bKuu6+7Az+8/8+4MM6H66m/PMN5/KlX21sdv0xg3qyo7B2HP17pnM4agdvqWunDOfZt/a0ev5q/+/ibH7yp50149dPy+Lp9QVxLzdaRloKpRVVTVdsR2MG9WTCkF68sGk/AOdn9+crcyfwg5fz+OPWQob3zWTP0eJa85w1vDeb9hxr8bqG981kZP8evLbjr/WmDe6VQd8e3dh24MMWL/eS8YP43yCJt6XJw3qzee+p1zllZF/e2nW00Xm+On8iH5ZW8ugftteU1d2GI/pnsvtwcazZm+2uj47h31/dEXPakN4ZHDhWGnNaY26YlsWvYuzj104Zzj99/ByWvfAeK/6yM8acEdHvw/XTsvjuDee2OAYAM1vv7jkxpyn5N1/d5C8i0hHyly1o1XyNJf+ENfuY2Vwz22pmeWa2JFFxNGTL3mO1mj9ERJJJQpK/maUCjwHzgEnAjWY2KRGxxLJpTxHzH/0TP3h5e9OVRUS6oEQd+U8H8tx9h7uXASuBhQmKpZ59RZGTde8UFCU4EhEJux/c
OKVdlpuo5D8c2B01XhCU1TCzO80s18xyCwvb/gSUu1PWwAnC6vMgFWr2CZWMtI77ODy4cHKz6i08bxjbHprHzm/N59MXjGq07utLZ7Uohje+OosrzhzconkAVi2eWWv8/Ufmc+2UWh9fLh43sNb4zLEDWryePpndao1/f9F5MesN75sJwP0LzmxwWfnLFpC/bAHvPzKfi8YObLBeW7v3yvGtmm/nt+bXDH/s3GFtFU4tiUr+FqOsVqZ19+XunuPuOYMGxbwddVzuemI94+9/oVaXvNKKSHe5/L9Ges282g69HqRj1E3kP/zU1Ebr33FRNlsenFszPrhXRq3pv/3CRW0XHDDjjEgy/MbVk/nTVy6rKf/pzafOzZ01vDffXzSF9LQUzIxvXnMW37n+HH5w4xTW339FreXdOH0kQ/t0Z/vD82Ku7/Q+3WuNP3XXhQzp3Z2f3JzDp84fWVN+x0XZ9ebNGdWPV758Kdsemsfjt0/nnKy+/OsnTiXi1BTje584j9mThgCw5ouX8E8fP4ebLhjJv/zNuXzvE+fyX3ecz6fOH8k3F07mrkvG1FvHE3dMJ3/ZAuZMjizjV5+9kI0PzI698epYc+8l5C9bwGcuHsP3F53Hv/zNuZw3om/N9B9FvfepKca3rju7ZvzHN02r2Teun5bFjdNH1Ft+9bJun3lq25x5eu+a4VtnjG4wts9cnM39C86kV/c0HvvkVH5809Ra66+2+p5LaoYfvvYszKzmC6u9JKS3j5ldCHzd3ecE40sB3P1bseq3R2+f6p47v/7cDKaN6gfAzGUv1+uKF73x1dunY11z3jB+s2FvvfIHF07ma89tjjlPikGVQ3pqCmVRfe+r38dn3yrg1+v38KXZ47n2h/8HRD68X786ciRe/R6/8dVZnP/IH2rNf7Ksgr9buYE1Ww7U6gKYv2wBOQ+9xKEP63cJ/P6i87j3qY3cOmM0P/vzzlrLi/ZvL2/nuy9uI3/ZgpoY3n1wLpnpqY1uo9z8w5SUVzFz7ADMIsdU6z84THFZFe/tP0b+X08w68whjBt8Go/+YTsnyir58uwJZA/sWWs5X3tuE7sOn+Q/bptes/7f3D2TVRv28oXLx9KvZ3qt+kUnyzn3wRf5t09O4apzWndk+tauI/RIT6NX9zSGBUfvde09WsyMZS8DkW259Jl3+OT0kTyVu5vSiioWXzaWL8waV2++giMnefQP2+mT2Y0l884kNSXW8eYpb+46wuRhvclIS+VEaQXvF37IvuB6hrOG9eHp9QXce+V4Xt1WyNHiMq6dkkVxWSXfXv0eX549AYCvPP02N+RkMWpATzLSUqiodEYO6BFzfXuOFlNV5ZRXVtEzI40hvbvzTkER2YN6clpG2z1mpdN19TSzNGAbMAvYA6wDPunuMT/RHZX8YyV3Jf/EufOSMSz/3/r9r/OXLeDwiTLueiKXdfm1L5zZ/vA8xt33AoN7ZXDBmAGs2riXv718LPcGH9Bo6z84wotb9vPFK8bTvVskyVa/x1sfmsuE+38PRI4W339kfr35q+vmL1tAUXE5n//FekrKq3jomrOY9/0/1UyrW3/ZdWezaPrIesuLtdxEuPPxXMzg3z8dM2d0uOrt8eznZzBlZL8ER9O1NJb8E/IkL3evMLPFwGogFVjRUOLvDFau3VXzBSEdI7NbKn+Tk1Uv+Ve3Uffvmc6Ifj3qJf9uqSksu+5sZpwxkCdezwfgtO6xd/Npo/rVe18vHDOA13b8lehjouif5NFun5nNvqLIL8U+md34xWcuaNZrayzxdwbLb+4cSb/aZz96Bi9s2qfE38YS9hhHd38eeD5R64+2r6iYxn4ALXnmnSZ/NkrzpaVYzJPpD3xsEt/4ny30ykjjnW/MAWDtfbOY/nCk+aWxI+El8yayJbiCtG5ybcmP25/f9hGOl1SQFrzfn8gZwdjBp8Ws+7WPNdw7+f4FZ/LmriMNTpfmWzJvIkvmTUx0GEmnSzzDt71d+K2XG5wW
fa8eadq2h+Yx/v4XGq3z2tJZfOThl2qVVSf2j0/LonvaqXbuwb268/t7LuZYcf3bYwyNOon52Y+eUW96dRt4S3TvllrTBLTxa7PpmdF4m3tDPnNx/ZOaT911Ib0z9ZGTziH0N3Z7/LX8Rqff/LM3OiSOZHDvleNJT0th20Pz+M3dMxusN6hOT5povbt3I71OT52JQ3szPbt/vbr3XNF4N7pLx0d6iZ0/puXdDAH69OhGWmrbfUSmZ/dn4tDeTVcU6QChS/4nSiu47edra8afi9GbJNpGXegV08ShvWqGq7vA9Qh6pqSnpdTrWljXjkfm8/t7Lo4rhvS0FL5w+Vie+fyMmNNnjB3Ijkfm1+r2JyIRoUv+a7Yc4JWt6r8fr+h2dA8u0YhuZhnSu3vMo/VqKSnWJkfBX5o9gamNnAhM0bkakZjUACmt8oVZY1n8y7eAU18EddPszDMGsnbnYYb3zeQ/b59O/qET1D1z8pcll9ecXBWRjqPkL61y1TnDopJ/9ZF/7Tp3fXQMZpGTselpKTF7zQxv4OIeEWlfSv7SYkN7R9rzV9yaw+Be3XkqN3KbprrH7927pfK3Ma6+lMb95OYcfSlKuwtdm38rev+FTt2bctVVFRzpXz5xCGcN71PTDVbXQrSNKycNYdIw9QqS9hW65C9Ne6yJm6DVbbevqIyUdGvDbpEi0r70aZV6enfv1uj0uveDqu7tU7d/voh0Xmrzlxare7uEf5g7kYy01Fbf3VFEOp6Sv7RYVZ3sP+C0DL55zVkJikZEWkO/06VZekbdV/4fr+o0j1sWkVZS8pdmiT7Wv25qVsLiEJG2oeQvzVK3qUdEujYlf2kWdzg3q0+iwxCRNqITvtIsZvDM52fqF4BIkgjFkX9JeSXPvlVQr3+6NM/dl53Brz83g9QU04VcIkkiFJ/k767eyhef3Mgftxa26ulOyeq1pZc3q97fz5nI5GFq8hFJJqFI/geOlwJwrKQ8wZF0Lqf30c3DRMIqFMlfTtn4wOxm1WvsMYwi0vUp+YdMn8zG79tTTY8+FEluoUj+auUXEaktruRvZjeY2WYzqzKznDrTlppZnpltNbM5UeVzg7I8M1sSz/pbSp19Wi5Dd+oUSUrx9vPfBFwH/Ht0oZlNAhYBk4FhwEtmNj6Y/BhwJVAArDOzVe6+Jc44GhXdwefB/9ncnqtKKs/dPZMhwVO7RCS5xJX83f1dIFb3yYXASncvBXaaWR4wPZiW5+47gvlWBnXbNflHO/RhWUetqss7V+3+IkmrvX7TDwd2R40XBGUNlddjZneaWa6Z5RYWFrZTmFItf9mCRIcgIh2oySN/M3sJGBpj0n3u/lxDs8Uoc2J/2cRsiXf35cBygJycHLXWd4AJQ3px3dSY38UikmSaTP7ufkUrllsAjIgazwL2BsMNlUuCrf7iJYkOQUQ6SHs1+6wCFplZhpllA+OAtcA6YJyZZZtZOpGTwqvaKYZ67nlyQ0etSkSkU4vrhK+ZXQv8ABgE/M7MNrj7HHffbGZPETmRWwHc7e6VwTyLgdVAKrDC3du9+436+YuI1BZvb59ngWcbmPYw8HCM8ueB5+NZrzTPf9z2EX7wch7rPziS6FBEpJPRFTxJ7NIJg1l82dhEhyEinVAokr9u4ywiUlsokr8e4iIiUlsokn9JeVWiQxAR6VRCkfw99nVkoRDm1y4iDQtH8lf+ExGpJRTJP8xMVzmISAyhSP5h7uxz0biB3HTByFplg3tlMHPsgARFJCKdQbz385dOrltqCg9dczb/9fqumrK197Xmdk0ikkxCkfzV9NGwvj266UIwkRAKRfIPU4+XG6eP5L/X7mq6YmDD12a3YzQi0lmFos0/TL195p8defTCRWMHJjgSEenMkj75V1U5L245kOgw2s2Vk4bUGg/TF52ItF7SJ/9Ne4sSHUK7auhsRph7OIlI05I++Sf7kXB0ko8eTvbXLSLxSfrkn+wyu6XWDBsNH/HfdcmYjglIRLoE
Jf8u7h+vmtSsekvnn0n+sgXtHI2IdBVK/l1cn8xuTBjSK9FhiEgXo+SfBJ7+3IWJDkFEuhgl/yRQ3e5vZkwb1Y9zs/qwdP7EBEclIp1ZKK7wDZMe6Wk8t/iiRIchIp1c0h/5Hz5ZlugQREQ6naRP/rf9fF2iQ2hXeji9iLRGXMnfzL5jZu+Z2dtm9qyZ9Y2attTM8sxsq5nNiSqfG5TlmdmSeNYvIiKtE++R/xrgLHc/B9gGLAUws0nAImAyMBf4oZmlmlkq8BgwD5gE3BjUlWaaOFTdOkUkfnElf3d/0d0rgtHXgaxgeCGw0t1L3X0nkAdMD/7y3H2Hu5cBK4O60kyLPjICgMsnDuY7159DaoqafUSk5dqyzf924IVgeDiwO2paQVDWUHk9ZnanmeWaWW5hYWEbhtm1VbfxD++byQ05kS8C3cZHRFqqya6eZvYSMDTGpPvc/bmgzn1ABfCL6tli1Hdif9nEzF3uvhxYDpCTk6P8Fmjs/K5+A4hIczWZ/N290Qe+mtktwFXALPeae0kWACOiqmUBe4PhhsqlBcL0dDIRaXvx9vaZC/wDcLW7n4yatApYZGYZZpYNjAPWAuuAcWaWbWbpRE4Kr4onhrDR0b2ItIV4r/D9NyADWBO0Rb/u7p91981m9hSwhUhz0N3uXglgZouB1UAqsMLdN8cZQ7gE7T7R9+tPDcr+Ya5u6SAizRNX8nf3sY1Mexh4OEb588Dz8aw3zKqP/KMbfVJSTLdrFpEWSforfJPBkN4Z9cr0pC4RiYdu7NZFXD8ti5H9e+jZvCLSJnTk38lMOr13zPLv3nAufztrXAdHIyLJSsm/k2nqyN5itvqLiLSMkn8XU/3loDZ/EYmHkn8XECvRK/mLSDyU/BNg3X0NXzTdVFLX+V4RaQtK/gkwqFf9rpstpds7iEg8lPy7GLX5i0hbUPLvYqp7+yj3i0g8lPw7mVhdPWslejX6i0gbUPJPkK/On8hPbs6pV97c5hw1+4hIPHR7hwS585Izml3XGhgWEWktHfl3Mj3SUxudnp4Wecsy0/XWiUjrKYN0Mg98bHKj0686Zxj3XDFO9+4Xkbio2aeT6ZPZrV5Z9Eng1BTjnivGd2BEIpKMdOTfyeiWzSLSEZT8Oxn14hGRjqDkLyISQkr+IiIhpOQvIhJCSv4J1pxHM5ou7RKRNqbkn2C9u9fubatbNYtIR4gr+ZvZN83sbTPbYGYvmtmwoNzM7FEzywumT42a5xYz2x783RLvCxARkZaL98j/O+5+jrufB/wW+FpQPg8YF/zdCfwIwMz6Aw8A5wPTgQfMrF+cMXRp6topIokQV/J392NRoz05dffhhcDjHvE60NfMTgfmAGvc/bC7HwHWAHPjiaGrqL4nj4hIZxB3RjKzh81sN/ApTh35Dwd2R1UrCMoaKo+13DvNLNfMcgsLC+MNM+Ge/fyMRqffcVE2O781v4OiEZGwazL5m9lLZrYpxt9CAHe/z91HAL8AFlfPFmNR3kh5/UL35e6e4+45gwYNat6r6cQa6rFTfYLXAGvg3g665YOItLUmb+zm7lc0c1m/BH5HpE2/ABgRNS0L2BuUX1qn/I/NXH5Sq07wQ3p3rzfty7MndHA0IpLs4u3tE91J/WrgvWB4FXBz0OvnAqDI3fcBq4HZZtYvONE7OyhLWhOH9mpR/e7dUslftqBmPH/ZAj4+LautwxKRkIv3ls7LzGwCUAV8AHw2KH8emA/kASeB2wDc/bCZfRNYF9R70N0PxxlDl6bePiKSCHElf3f/eAPlDtzdwLQVwIp41tsVNdVu31B7v4hIe1D/wzj175nerHo6wheRzkTJP8FSgiN+HfiLSEfSYxwT7KYLRvHB4RMsvmxsokMRkRBR8u8gDR3ZZ6an8tA1Z3dsMCISemr2iVN33bZBRLogZa5W6NejW83w7MlD+fFN0+rVuXjcQF6696O1yn74qan8/p6L2z0+EZGmqNmnFc7PHsDvN+8HIids5541
tF6dJ+44H6jdy2f+2ad3SHwiIk3RkX87GNE/s15ZS3vzZKg5SUTakY78WyFWIv/LksvZX1RCWUUVYwefFvc6/vj3l7LnSHHcyxERiUXJvxViXbA1vG8mw/vWP+JvrdP7ZHJ6n7ZbnohINLUttML4IafFHI5l2ujIg8r6ZHZrtJ6ISEfSkX8zDO6VwcHjpTXj103NYvbkyEneycN6Nzrv1z82mZsvHKWjeBHpVJT8m6FuG78ZnDW8T7PmTU9LYeLQxr8gREQ6mpp9RERCSMk/8OY/Xlmv7A9fOnWR1rsPzmVQr4yODElEpN0o+Qd6da/fAtYz/VRZZnoqmd1SOzIkEZF2o+RfxyfPH1kz3De4jcNnLhoDwO0zRwMw4DT9AhCRrk0nfAMGNc/O/eUbu4D6z9O9dWY2t87MTkR4IiJtSkf+AT1GUUTCREf+gZSo3D9mUE8+fcGoxAUjItLOlPwD0Uf+L3/p0sQFIiLSAdTsIyISQm2S/M3sy2bmZjYwGDcze9TM8szsbTObGlX3FjPbHvzd0hbrj9cdF+kkroiES9zNPmY2ArgS2BVVPA8YF/ydD/wION/M+gMPADmAA+vNbJW7H4k3jtb68U3TYj6MRUQkmbXFkf/3gK8QSebVFgKPe8TrQF8zOx2YA6xx98NBwl8DzG2DGOIQ4/7MIiJJLq7kb2ZXA3vcfWOdScOB3VHjBUFZQ+UJ0y1Vpz1EJHyabPYxs5eAWO0i9wFfBWbHmi1GmTdSHmu9dwJ3AowcOTJWlbhcNmEQr2wtbPPlioh0BU0mf3e/Ila5mZ0NZAMbg26SWcCbZjadyBH9iKjqWcDeoPzSOuV/bGC9y4HlADk5OW3eNjMwuEVDZrru1yMi4dPqE77u/g4wuHrczPKBHHc/ZGargMVmtpLICd8id99nZquBR8ysXzDbbGBpq6OPw9evnszUUf24cMyARKxeRCSh2usir+eB+UAecBK4DcDdD5vZN4F1Qb0H3f1wO8XQqJ4Zadw4ve2bk0REuoI2S/7uPjpq2IG7G6i3AljRVuttTEl5ZUesRkSky0nqri7HissTHYKISKeU1MlfRERiU/IXEQmhUCb/t2I8r1dEJExCmfz79UxPdAgiIgkVyuQvIhJ2Sv4iIiGU3Mlfj+UVEYkpuZO/iIjEpOQvIhJCSv4iIiGk5C8iEkJK/iIiIaTkLyISQkmd/E19PUVEYkrq5O+xHw8sIhJ6SZ38RUQktqRO/mr2ERGJLamTv4iIxKbkLyISQm32APeu4NW/v5Q9R4sTHYaISMKFKvmPGtCTUQN6JjoMEZGEU7OPiEgIxZX8zezrZrbHzDYEf/Ojpi01szwz22pmc6LK5wZleWa2JJ71i4hI67RFs8/33P270QVmNglYBEwGhgEvmdn4YPJjwJVAAbDOzFa5+5Y2iKMeU09PEZGY2qvNfyGw0t1LgZ1mlgdMD6blufsOADNbGdRtl+TvusBXRCSmtmjzX2xmb5vZCjPrF5QNB3ZH1SkIyhoqr8fM7jSzXDPLLSwsbIMwRUSkWpPJ38xeMrNNMf4WAj8CzgDOA/YB/1w9W4xFeSPl9Qvdl7t7jrvnDBo0qFkvpn7srZpNRCTpNdns4+5XNGdBZvYT4LfBaAEwImpyFrA3GG6oXEREOki8vX1Ojxq9FtgUDK8CFplZhpllA+OAtcA6YJyZZZtZOpGTwqviiUFERFou3hO+3zaz84g03eQDdwG4+2Yze4rIidwK4G53rwQws8XAaiAVWOHum+OMQUREWiiu5O/un25k2sPAwzHKnweej2e9IiISH13hKyISQkr+IiIhlNTJP7qn59De3RMWh4hIZ5PUyT/6AoL0tKR+qSIiLaKMKCISQkmd/HWBr4hIbEmd/EVEJDYlfxGREFLyFxEJISV/EZEQUvIXEQkhJX8RkRBS8hcRCSElfxGREErq5H9a9/Z6Pr2ISNeW1Mk/Iy010SGIiHRKSZ38RUQkNiV/EZEQUvIXEQkhJX8R
kRBS8hcRCSElfxGREApN8v/q/DMTHYKISKcRd/I3sy+Y2VYz22xm344qX2pmecG0OVHlc4OyPDNbEu/6m2vuWUM7alUiIp1eXJfAmtllwELgHHcvNbPBQfkkYBEwGRgGvGRm44PZHgOuBAqAdWa2yt23xBOHiIi0TLz3P/gcsMzdSwHc/WBQvhBYGZTvNLM8YHowLc/ddwCY2cqgrpK/iEgHirfZZzxwsZm9YWavmtlHgvLhwO6oegVBWUPl9ZjZnWaWa2a5hYWFcYYpIiLRmjzyN7OXgFgN5vcF8/cDLgA+AjxlZmMAi1Hfif1l47HW6+7LgeUAOTk5MeuIiEjrNJn83f2KhqaZ2eeAZ9zdgbVmVgUMJHJEPyKqahawNxhuqLxdPHrjFPpmdmvPVYiIdDnxNvv8BrgcIDihmw4cAlYBi8wsw8yygXHAWmAdMM7Mss0snchJ4VVxxtCoq88dxiXjB7XnKkREupx4T/iuAFaY2SagDLgl+BWw2cyeInIitwK4290rAcxsMbAaSAVWuPvmOGMQEZEWskiu7txycnI8Nzc30WGIiHQpZrbe3XNiTQvNFb4iInKKkr+ISAgp+YuIhJCSv4hICCn5i4iEkJK/iEgIdYmunmZWCHwQxyIGErn4TLQt6tL2qE3b45Rk2Baj3D3mVa5dIvnHy8xyG+rrGjbaFrVpe9Sm7XFKsm8LNfuIiISQkr+ISAiFJfkvT3QAnYi2RW3aHrVpe5yS1NsiFG3+IiJSW1iO/EVEJIqSv4hICCV18jezuWa21czyzGxJouNpD2Y2wsxeMbN3zWyzmf1dUN7fzNaY2fbgf7+g3Mzs0WCbvG1mU6OWdUtQf7uZ3ZKo19QWzCzVzN4ys98G49nBs6a3m9mTwcOECB449GSwPd4ws9FRy1galG81szmJeSXxM7O+Zva0mb0X7CcXhnX/MLMvBp+TTWb232bWPbT7hrsn5R+Rh8W8D4wh8oSxjcCkRMfVDq/zdGBqMNwL2AZMAr4NLAnKlwD/FAzPB14g8pzlC4A3gvL+wI7gf79guF+iX18c2+Ve4JfAb4Pxp4BFwfCPgc8Fw58HfhwMLwKeDIYnBftMBpAd7EupiX5drdwW/wl8JhhOB/qGcf8AhgM7gQYiM1EAAALqSURBVMyofeLWsO4byXzkPx3Ic/cd7l4GrAQWJjimNufu+9z9zWD4OPAukZ18IZEPPcH/a4LhhcDjHvE60NfMTgfmAGvc/bC7HwHWAHM78KW0GTPLAhYAPw3GjcjjRp8OqtTdHtXb6WlgVlB/IbDS3UvdfSeQR2Sf6lLMrDdwCfAzAHcvc/ejhHf/SAMyzSwN6AHsI6T7RjIn/+HA7qjxgqAsaQU/S6cAbwBD3H0fRL4ggMFBtYa2SzJtr38FvgJUBeMDgKPuXhGMR7+2mtcdTC8K6ifL9hgDFAI/D5rBfmpmPQnh/uHue4DvAruIJP0iYD0h3TeSOflbjLKk7ddqZqcBvwbucfdjjVWNUeaNlHcpZnYVcNDd10cXx6jqTUxLiu1B5Eh3KvAjd58CnCDSzNOQpN0ewXmNhUSaaoYBPYF5MaqGYt9I5uRfAIyIGs8C9iYolnZlZt2IJP5fuPszQfGB4Oc6wf+DQXlD2yVZttdM4GozyyfS1Hc5kV8CfYOf+lD7tdW87mB6H+AwybM9CoACd38jGH+ayJdBGPePK4Cd7l7o7uXAM8AMQrpvJHPyXweMC87kpxM5YbMqwTG1uaAN8mfAu+7+L1GTVgHVPTJuAZ6LKr856NVxAVAU/OxfDcw2s37BEdLsoKxLcfel7p7l7qOJvOcvu/ungFeA64NqdbdH9Xa6PqjvQfmioMdHNjAOWNtBL6PNuPt+YLeZTQiKZgFbCOf+sQu4wMx6BJ+b6m0Ryn0j4Wec2/OPSM+FbUTOxt+X6Hja6TVeROQn59vAhuBvPpG2yT8A24P//YP6BjwWbJN3gJyoZd1O5ORVHnBbol9bG2ybSznV22cM
kQ9oHvArICMo7x6M5wXTx0TNf1+wnbYC8xL9euLYDucBucE+8hsivXVCuX8A3wDeAzYBTxDpsRPKfUO3dxARCaFkbvYREZEGKPmLiISQkr+ISAgp+YuIhJCSv4hICCn5i4iEkJK/iEgI/X9N6hO2IG9JwgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = DoubleQLearningAgent(env)\n",
    "\n",
    "# 训练\n",
    "episodes = 9000\n",
    "episode_rewards = []\n",
    "for episode in range(episodes):\n",
    "    episode_reward = play_qlearning(env, agent, train=True)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    \n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "# 测试\n",
    "agent.epsilon = 0. # 取消探索\n",
    "\n",
    "episode_rewards = [play_qlearning(env, agent) for _ in range(100)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SARSA($\\lambda $) 算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SARSALambdaAgent(SARSAAgent):\n",
    "    def __init__(self, env, lambd=0.6, beta=1.,\n",
    "            gamma=0.9, learning_rate=0.1, epsilon=.01):\n",
    "        super().__init__(env, gamma=gamma, learning_rate=learning_rate,\n",
    "                epsilon=epsilon)\n",
    "        self.lambd = lambd\n",
    "        self.beta = beta\n",
    "        self.e = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "        \n",
    "    def learn(self, state, action, reward, next_state, done, next_action):\n",
    "        # 更新资格迹\n",
    "        self.e *= (self.lambd * self.gamma)\n",
    "        self.e[state, action] = 1. + self.beta * self.e[state, action]\n",
    "\n",
    "        # 更新价值\n",
    "        u = reward + self.gamma * \\\n",
    "                self.q[next_state, next_action] * (1. - done)\n",
    "        td_error = u - self.q[state, action]\n",
    "        self.q += self.learning_rate * self.e * td_error\n",
    "        if done:\n",
    "            self.e *= 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均回合奖励 = 815 / 100 = 8.15\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAcl0lEQVR4nO3deXxU9b3/8deHkLAEWQMIYUlAREEQMCCiVVQKiAuttVf9uaD1lrrVtre1P5f+Sm1r9dp7662t1su11Pr7taUuraUVLxfr0lo3AoiACwRECKDsIRCyf35/zEmcMJMAmUkmmfN+Ph7zyDnf851zvt9Z3nPme86ZmLsjIiLh0iHVDRARkdan8BcRCSGFv4hICCn8RURCSOEvIhJCHVPdgKORk5PjeXl5qW6GiEi7snz58l3u3jfesnYR/nl5eRQWFqa6GSIi7YqZfdTYMg37iIiEkMJfRCSEFP4iIiGk8BcRCSGFv4hICCn8RURCSOEvIhJCoQz/pwq38P7H+2PKy6tqKN5bxv7yKor3lvHksi3U1jb8yeudpRWUlFVRWV3LR7sP1pcfqqzh9Q27qayupbS8ik/2lwOweXcZldW19fW2lxziYEV1zLaramrjbg+gptb5cNfBmPLD/fea7ewoLW9QtnXfIQ5V1sRsa83WEnbsL6e8qoanCrfg7ny0+yB/X78zZr17Dlby7rb9lJZXNbrtHfvL2X/Y8pKyKnaWVjR6n3j92rKnjPKqGrbtO0RZZeRxqqyO/9iUllfx7MqtABTtKGXDzgMx2yht4rmMdrCimpfe38HDLxWxr6yyvnzXgYoG8/vLq1iztYSSQw37WlUTeT24O0U7GrbD3XmycAsV1ZHnobbWWfdJKZt3lwVtP0BtbaROeVWkTsmhKv709lYqq2vr65Ucqop5fg9XvLes/nE73D+KdrFx5wEWr97OnoOVcevUqal1dpSWs3LzXgBe37CbdZ+UxtSLbl9TNu06yI8Wv1ffv5awZU8ZL3+wo8HjX7hpD+9t389TUY9tba3z5LIt7NhfznPvbI9Zz4rNe1m7raR+/qX3d1C8N9LHmlrn98s2s2LzXlZs3svBimpKyqp4fcNuIPL6+Ljk0+eoeG8Z+8oqWbF5LyWHqjhUWVO/Log896n6WX1rD7/nX1BQ4Mm6yOtgRTWj5y0B4PYZI5mU35v9h6ronZ3FT5au4+/rd8Xcp3d21hHfLK0hOyuDoX2yeXd77AdXW5bTrRO7DjT+IXBKbnfWbE1un2aM7s+Bimr+UbQ7qettT3pnZzH/mtO47NHXj6r+tJP78+62EraVNP0B0xoyOhg1h31Yf2HCIGpqa3n27W31ZZ06dqAiaucqmTIzjKqalsnHnl0z2VfW+M7UF08bxFPLiwE4d2RffnX9pGZtx8yWu3tB3GUp+9Qxmwn8FMgAHnP3+xurm8zwf3p5Md96alVS1iUi0ho23X9hs+7XVPinZNjHzDKAh4ELgFHAlWY2qjW2ffcfV7fGZkRE2rRUjflPAorcfaO7VwILgdmtseGW+orYngzs0TnVTTgqlxcMPub7XHDK8fTsmtno8uO7H13fZ48b2GB+/JCejda955LRR9e4VpDTrVPc8rlnD0vaNm45d3iTy6+bkhdTdvKA7g3mLxwzoH569MDuh1dvdT++bGxMWZ/srGav7+apTT9GAE9+5QxyunViTG6P+rLT83s3qNM5swNF917Q7HY0JVU/7JYLbImaLwZOj65gZnOBuQBDhgxpvZa1U//3hklc88u3mqxz1gk5PHj5OKpqaply/4sAfG7cQB68fByPvLyBHy/5oEH9R6+ewI3/b0X9/C+umsC5J/Xjtt+tZF9ZFZ+fkMuY3B6cErx4t+wp4zMPvMTNU4dz6YRcpv3kb3HbMW5wT75z4cls3HmQbz/zDqfkdscwbjt/BJ8d1Z+8O54D4F8vG8s9s0fzxsbdVNU4
b27czWOvfsh1U/J4/LVNcdd936VjqKypZdK9f+WHnzuF3QcqefCFdQD86vqJnDuyH3l3PMe1Zwzlidcjv3n1238+nSkn5LBx5wHO+/dXAPjpFeP56RXj69vyx5vPrJ8GGNqnK49dW0Cv7CxyunUiu1NHBvfqQkFeb4bftZhhOdls3HWQywsG8/vCLZw8oDvvNXKs5rbzTmBHaQULl336lrjhrHx++eqH5PbswtZ9h+rL75p1EnPPjgTL//qvN3gtOND492+fy+DeXWPWXdfmu2adzFfOHsZpP3yBeReP4sT+x3HVY28C8MxNU6iqqaW0vJovP/Hp8OoZw/rw+sbdPHTleL7/57XsOlDJqnnT2VlawcMvbaivt+n+C+u3s/7eC8jM6EDf4zo1eD09c9MZdM1qGDeHHl/G2SNyuO7M/Pr7P/GlSVy74NPX8ePXT2TqyH4N+gIwol831gcHdj+8bxYvvLeD807qx4NL17F49XY2Rp1IcOn4XP4QnBgAMGV4H17bsJv/uHwcg3t3Zf0npXyxYDAPLl1HRXUtS75xNn2yszAzAPaVVXLxz19ly55Pn4fnbjuL0QN7kHfHc4wd1IMumRm8+eGe+uXXTcnj8omDOefHLwPwl6+exUU/e7VB/yfl96bwO9OAyMkmb2zczdSR/bj9qVU8tbyYoX268srt58Y8p8mSkjF/M/siMMPd/zmYvwaY5O5fjVc/GWP+n3/kH3TNykiLA4B/uuVM3t6yj3mL1taX/fnWs7j456/Grf/sLWeSYcaYQZ/uYfzurc2cf1I/+kXtCdfUOis3760/QPj+D2ayv7yKJWs+5poz8o6qbdU1tXTMiHyhrHuzbrr/QlZs3ktFVS1PLy/mx5eNpUOHyBtr3SelHN+jM907f7q3vudgJSWHqsjPyW6w7pJDVXxt4Uoe+MJYCj/ay82/WcHTN57Bys37uHbKUDp1zKivW1VTS8cOhjv8/KUiVm7ey4LrJmIWOZDYwSD/zsUAbPzRrPr23PHMO1w0diBnjcgBYNmmPVRV1zLlhByqa2q59bcr+cJpg/jsqP6NPgZFO0rp260z2Z0y6g9cZnSw+jD5R9EuMjM6cNrQXrhHln28v5wbHi/krlkns2n3Qa6ePJQ1W0sY3rcbmRnG+x+XkpeTTXZWRv16SsurWF1cwsZdkfrxfG3hSl58fwervzej/vmpa8s7xfsY0e84umRFHrfK6lq++rsVfGv6SPJzshu0ee/BSvaWVTKsbzeqa2r5xpOrOKFvNz4pLedHnx/DsDufY86UPOZdHPkW5O78+rVNDOjZhedXb+fBy8fVryue14p2sXprCV85ZziPvrKBNVtL+O5Foxq8Pk/8zvNcdtogxg/uyZQTcvho90E6dujApMP2lneWVjDx3he48ZzhnD6sN6cN7cXThcWM6N+Nxas/5r5LxzR4ndapOxus7rVwuD+v2katO/2O68wZw/sAkfdMXW0HPtx1kF5dM+kTfAOrrI68Ds3gidc/qn/PXnzqQH525fi42zlQUc3CtzZzw1n5TT5mR6PNHfA1szOA77n7jGD+TgB3vy9e/WSEf/ReQ1vwb188tVkHnpd8/WxGHn8c0LBP8cK/7iylwu9Ma3Q4IJ7o0E5EstbTUtp6+yT9tPZrrs0d8AWWASPMLN/MsoArgEUpaktKTAn2HI5V9M7KScGHABB3nHvysMgeUefMjJhlTXnkqgnNbl+0l781lRe/eU7C6xFJJ0d73KmlpWTM392rzexWYAmRUz0XuPvaI9ytTbv386cwqFdX5ixofNz9wrED4l5UAjBrzPHcfeEolq79mO/9+d36Mnc484QcvvPsGoAGXwOfvPEMPikpp7rWGdy7Kw9dOZ7ivWU88N+Rsdaf/NM4vjGtjG6dju1pnjVmALOiDsg1V95hwzZtTd2Ytkhr+dvt59KjS+MnJLSmlP0nL3dfDCxO1faT7arTh/LKuoZXx/5+7mQun/8GEDnb4q5ZJ/PcO89x
Sm53Dh9se+Sq0wC47sz8+vCvK4u+sjK3Z5f66e6dMxuMlV9yauQMlYl5vVm7tYTOmRmM6P/ptwNp6FfXT6S0PP7VsCItYUif2IPyqdIu/o1je1V3sGr8kJ7cPmMkAB/8cCYZZhysaPwy9xmj+/O5cbn187XBcZkT+3c7qiGciXm9mZjX+4j1wq5zZsYxD4mJpAuFfwvKz8lm6TfOJj8nu/7MgrozUnp07cBfv3kOX36ikPGDezW4339e0/D4TN0xeSOxI/8iInUU/klw6uDGLwBqathleN9uvPjNqUdcf334K/tFJEkU/gm68Zzh9UM6dUb2P47/vOa0pG3DgyMEiZ7zKyJSJ5Q/6ZxMHTsYGYddFNKve6eknukytE9kXV87f0TS1iki4aY9/ySaMKQnQ/t05ZvTRx658jHo1qmjLkQSkaRS+CfRcZ0zW/S3OEREkkXDPiIiIaTwT5DHXK4lItL2KfxFREJI4Z+gz4/PPXIlEZE2RuGfoLrTMEVE2hOFv4hICCn8E6RrbkWkPVL4i4iEkMI/Qfq9HRFpjxT+IiIhpPBPkPb7RaQ9UviLiISQwj9BGvIXkfZI4S8iEkIK/wTpbB8RaY8U/iIiIaTwFxEJIYW/iEgIKfxFREJI4S8iEkIKfxGREFL4i4iEkMJfRCSEFP4iIiGk8BcRCSGFv4hICCn8RURCSOF/DCbl9051E0REkkLhLyISQgmFv5n92MzeN7N3zOyPZtYzatmdZlZkZh+Y2Yyo8plBWZGZ3ZHI9kVEpHkS3fNfCpzi7mOBdcCdAGY2CrgCGA3MBB4xswwzywAeBi4ARgFXBnVFRKQVJRT+7v4/7l4dzL4BDAqmZwML3b3C3T8EioBJwa3I3Te6eyWwMKjbYtydDTsPtOQmRETanWSO+X8JeD6YzgW2RC0rDsoaK49hZnPNrNDMCnfu3NnsRj2zYivn//srzb6/iEg66nikCmb2AnB8nEV3u/ufgjp3A9XAb+ruFqe+E//DxuNt193nA/MBCgoK4tY5Gqu27GvuXUVE0tYRw9/dpzW13MzmABcB57t7XUgXA4Ojqg0CtgXTjZW3CP2LXRGRWIme7TMT+N/AJe5eFrVoEXCFmXUys3xgBPAWsAwYYWb5ZpZF5KDwokTaICIix+6Ie/5H8HOgE7DUIrvYb7j7je6+1syeBN4lMhx0i7vXAJjZrcASIANY4O5rE2xDk7TjLyISK6Hwd/cTmlh2L3BvnPLFwOJEtnssTOM+IiIxdIWviEgIKfxFREIo7cNfoz4iIrHSPvxFRCSWwj8Bs8bEu/ZNRKTtS/vwN53sKSISI/3DvwWzPzsr0cskRERSI+3Tq6Wy/9szR3L15KEttHYRkZaV9uHfUm6e2uj1bSIibZ6GfUREQigE4a/0FxE5XNqHv4iIxFL4i4iEUNof8E32oM/PrhzPe9v3J3mtIiKtK+3DP9npf/GpA7n41IHJXamISCtL+2EfXeErIhIr/cNf2S8iEiP9wz/VDRARaYPSP/yV/iIiMdI//LXvLyISI/3DX9kvIhIj/cM/1Q0QEWmD0j/8tesvIhIjBOGf6haIiLQ96R/+zRz42XT/hUluiYhI25H24d9Be/4iIjHSPvw17CMiEivtw19ERGKlffjrbB8RkVhpH/4iIhJL4X8Uvj97dKqbICKSVGkf/mWV1QmvY3jfbkloiYhI25H24X+g/Mjh/9Zd57dCS0RE2o60D/+j0a9751Q3QUSkVaV9+Hucsie+NKnV2yEi0pakffjHO9Hz7BP7JrwOEZH2LP3DX+f5i4jESEr4m9m3zMzNLCeYNzN7yMyKzOwdM5sQVXeOma0PbnOSsX0RETk2HRNdgZkNBj4LbI4qvgAYEdxOB34BnG5mvYF5QAGR4fjlZrbI3fcm2o7GuMcb9RcRCbdk7Pk/CHybhsdWZwNPeMQbQE8zGwDMAJa6+54g8JcCM5PQhhaV1THyMHXvnJniloiIJEdCe/5mdgmw1d1XHTa2
ngtsiZovDsoaK4+37rnAXIAhQ4Yk0sZm37fO+CG9+D8XjeLS8XGbKiLS7hwx/M3sBeD4OIvuBu4Cpse7W5wyb6I8ttB9PjAfoKCgIKVjNwbccFZ+KpsgIpJURwx/d58Wr9zMxgD5QN1e/yBghZlNIrJHPziq+iBgW1A+9bDyl5vRbhERSUCzx/zdfbW793P3PHfPIxLsE9z9Y2ARcG1w1s9koMTdtwNLgOlm1svMehH51rAk8W6IiMixaKnz/BcDG4Ei4L+AmwHcfQ/wA2BZcPt+UNbqrpuS1+iyeRePar2GiIikQMKnetYJ9v7rph24pZF6C4AFydpucw3rmx23PLdnF64/U+P7IpLe0v4KXxERiaXwDzx7y5mNLtMvRIhIulH4B/pkZ6W6CSIirUbhfxjt5YtIGCj8RURCSOEvIhJCCv9A3+M60cHg9hkj68sunaDf8hGR9JS08/zbu86ZGWy878IGZQ98YSzzLh6tfwgjImlH4d+Ejhkd6NFFX45EJP2kfbI9/tqmVDdBRKTNSfvwFxGRWKEP/6snN/8fxYiItFehDX8dwhWRMAtt+OvfuotImIU2/EVEwiy04a9hHxEJs9CGv4Z9RCTMQhv+IiJhFtrw17CPiIRZaMNfRCTMFP4iIiGk8BcRCSGFv4hICCn8RURCSOEvIhJCCn8RkRBS+IuIhJDCX0QkhEIf/q4f+RGREAp9+IuIhFHow9/0Iz8iEkKhD38N+4hIGIU+/EVEwij04a9hHxEJo9CHv4Z9RCSMQh/+IiJhFPrw17CPiIRRwuFvZl81sw/MbK2ZPRBVfqeZFQXLZkSVzwzKiszsjkS3nygN+4hIGHVM5M5mdi4wGxjr7hVm1i8oHwVcAYwGBgIvmNmJwd0eBj4LFAPLzGyRu7+bSDtEROTYJBT+wE3A/e5eAeDuO4Ly2cDCoPxDMysCJgXLitx9I4CZLQzqpiz8NewjImGU6LDPicBnzOxNM3vFzCYG5bnAlqh6xUFZY+UxzGyumRWaWeHOnTsTbGbjNOwjImF0xD1/M3sBOD7OoruD+/cCJgMTgSfNbBgQb3/aif9hEzd+3X0+MB+goKBAES0ikkRHDH93n9bYMjO7CfiDuzvwlpnVAjlE9ugHR1UdBGwLphsrFxGRVpLosM+zwHkAwQHdLGAXsAi4wsw6mVk+MAJ4C1gGjDCzfDPLInJQeFGCbWgeDfaLSIglesB3AbDAzNYAlcCc4FvAWjN7ksiB3GrgFnevATCzW4ElQAawwN3XJtgGERE5RgmFv7tXAlc3suxe4N445YuBxYlsV0REEhP6K3xFRMJI4S8iEkIKfxGREFL4i4iEkMJfRCSEQhv+eX26AnDSgO4pbomISOtL9Dz/duszI/ry3G1nMUrhLyIhFNrwBxg9sEeqmyAikhKhHfYREQkzhb+ISAgp/EVEQkjhLyISQgp/EZEQUviLiISQwl9EJIQU/iIiIaTwFxEJIYW/iEgIKfxFREJI4S8iEkIKfxGREFL4i4iEkMJfRCSEFP4iIiGk8BcRCSGFv4hICCn8RURCSOEvIhJCCn8RkRBK6/B391Q3QUSkTUrr8BcRkfgU/iIiIZTW4a9RHxGR+NI6/EVEJD6Fv4hICKV1+GvUR0QkvrQO/3iyMkLXZRGRGAkloZmNM7M3zOxtMys0s0lBuZnZQ2ZWZGbvmNmEqPvMMbP1wW1Ooh1oyuHn+a/67nRWfPezLblJEZF2oWOC938AuMfdnzezWcH8VOACYERwOx34BXC6mfUG5gEFREZllpvZInffm2A7jkqPrpmtsRkRkTYv0TEQB7oH0z2AbcH0bOAJj3gD6GlmA4AZwFJ33xME/lJgZoJtEBGRY5Tonv/XgSVm9m9EPkimBOW5wJaoesVBWWPlMcxsLjAXYMiQIc1qnA74iojEd8TwN7MXgOPjLLobOB/4hrs/Y2b/BPwSmAZYnPreRHlsoft8YD5AQUFBs3K8Vld5iYjEdcTwd/dpjS0z
syeArwWzTwGPBdPFwOCoqoOIDAkVEzkmEF3+8lG39hiVlFW11KpFRNq1RMf8twHnBNPnAeuD6UXAtcFZP5OBEnffDiwBpptZLzPrBUwPylpE9y6fHuA9vnvnltqMiEi7k+iY/5eBn5pZR6CcYIweWAzMAoqAMuB6AHffY2Y/AJYF9b7v7nsSbEOjOmdm1E8P6KnwFxGpk1D4u/urwGlxyh24pZH7LAAWJLLdY3HPJaOZt2hta21ORKRdSPvLXU/J7ZHqJoiItDlpH/4iIhJL4S8iEkIKfxGREEr78M/oELmurFPHtO+qiMhRS/RUzzbv1EE9uO28E7hq8tBUN0VEpM1I+/A3M/5l+shUN0NEpE3RWIiISAgp/EVEQkjhLyISQgp/EZEQUviLiISQwl9EJIQU/iIiIaTwFxEJIfN28H9uzWwn8FECq8gBdiWpOe1F2Poctv6C+hwWifR5qLv3jbegXYR/osys0N0LUt2O1hS2Poetv6A+h0VL9VnDPiIiIaTwFxEJobCE//xUNyAFwtbnsPUX1OewaJE+h2LMX0REGgrLnr+IiERR+IuIhFBah7+ZzTSzD8ysyMzuSHV7EmFmC8xsh5mtiSrrbWZLzWx98LdXUG5m9lDQ73fMbELUfeYE9deb2ZxU9OVomdlgM3vJzN4zs7Vm9rWgPG37bWadzewtM1sV9PmeoDzfzN4M2v97M8sKyjsF80XB8ryodd0ZlH9gZjNS06OjY2YZZrbSzP4SzKd7fzeZ2Woze9vMCoOy1n1du3ta3oAMYAMwDMgCVgGjUt2uBPpzNjABWBNV9gBwRzB9B/CvwfQs4HnAgMnAm0F5b2Bj8LdXMN0r1X1ros8DgAnB9HHAOmBUOvc7aHu3YDoTeDPoy5PAFUH5o8BNwfTNwKPB9BXA74PpUcFrvhOQH7wXMlLdvyb6/S/Ab4G/BPPp3t9NQM5hZa36uk7nPf9JQJG7b3T3SmAhMDvFbWo2d/8bsOew4tnAr4PpXwOfiyp/wiPeAHqa2QBgBrDU3fe4+15gKTCz5VvfPO6+3d1XBNOlwHtALmnc76DtB4LZzODmwHnA00H54X2ueyyeBs43MwvKF7p7hbt/CBQReU+0OWY2CLgQeCyYN9K4v01o1dd1Ood/LrAlar44KEsn/d19O0SCEugXlDfW93b7mARf78cT2RNO634HQyBvAzuIvKE3APvcvTqoEt3++r4Fy0uAPrSvPv8H8G2gNpjvQ3r3FyIf6P9jZsvNbG5Q1qqv63T+B+4Wpyws57U21vd2+ZiYWTfgGeDr7r4/sqMXv2qcsnbXb3evAcaZWU/gj8DJ8aoFf9t1n83sImCHuy83s6l1xXGqpkV/o5zp7tvMrB+w1Mzeb6Jui/Q5nff8i4HBUfODgG0paktL+ST4+kfwd0dQ3ljf291jYmaZRIL/N+7+h6A47fsN4O77gJeJjPP2NLO6nbXo9tf3LVjeg8jwYHvp85nAJWa2icjQ7HlEvgmka38BcPdtwd8dRD7gJ9HKr+t0Dv9lwIjgrIEsIgeHFqW4Tcm2CKg7wj8H+FNU+bXBWQKTgZLga+QSYLqZ9QrOJJgelLVJwVjuL4H33P0nUYvStt9m1jfY48fMugDTiBzreAm4LKh2eJ/rHovLgBc9cjRwEXBFcHZMPjACeKt1enH03P1Odx/k7nlE3qMvuvtVpGl/Acws28yOq5sm8npcQ2u/rlN91Lslb0SOkq8jMmZ6d6rbk2BffgdsB6qIfOLfQGSs86/A+uBv76CuAQ8H/V4NFESt50tEDoYVAdenul9H6PNZRL7GvgO8HdxmpXO/gbHAyqDPa4DvBuXDiIRZEfAU0Cko7xzMFwXLh0Wt6+7gsfgAuCDVfTuKvk/l07N90ra/Qd9WBbe1ddnU2q9r/byDiEgIpfOwj4iINELhLyISQgp/EZEQUviLiISQwl9EJIQU/iIiIaTwFxEJof8PxH7ObUcY7EcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "agent = SARSALambdaAgent(env)\n",
    "\n",
    "# 训练\n",
    "episodes = 5000\n",
    "episode_rewards = []\n",
    "for episode in range(episodes):\n",
    "    episode_reward = play_sarsa(env, agent, train=True)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    \n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "# 测试\n",
    "agent.epsilon = 0. # 取消探索\n",
    "\n",
    "episode_rewards = [play_sarsa(env, agent, train=False) for _ in range(100)]\n",
    "print('平均回合奖励 = {} / {} = {}'.format(sum(episode_rewards),\n",
    "        len(episode_rewards), np.mean(episode_rewards)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
